From 7c080dd3c5b794f31906aff92f4cb829cde4986a Mon Sep 17 00:00:00 2001 From: mikaylagawarecki Date: Tue, 31 Mar 2026 13:21:13 -0400 Subject: [PATCH] [4/n] Migrate FP4/W4A8 CUTLASS kernels to torch stable ABI (#37503) Signed-off-by: Mikayla Gawarecki --- CMakeLists.txt | 178 +++++----- csrc/cuda_vec_utils.cuh | 4 +- csrc/cutlass_extensions/cute_utils.cuh | 1 - .../broadcast_load_epilogue_array_c3x.hpp | 40 +-- .../epilogue/broadcast_load_epilogue_c3x.hpp | 40 +-- csrc/cutlass_extensions/torch_utils.hpp | 87 +++-- csrc/libtorch_stable/dispatch_utils.h | 9 + csrc/libtorch_stable/ops.h | 50 +++ .../cutlass_w4a8/get_group_starts.cuh | 61 ++-- .../cutlass_w4a8/w4a8_grouped_mm_entry.cu | 146 ++++---- .../cutlass_w4a8/w4a8_mm_entry.cu | 116 ++++--- .../quantization/cutlass_w4a8/w4a8_utils.cu | 0 .../quantization/cutlass_w4a8/w4a8_utils.cuh | 0 .../activation_nvfp4_quant_fusion_kernels.cu | 34 +- .../fp4/nvfp4_blockwise_moe_kernel.cu | 325 +++++++++++------- .../quantization/fp4/nvfp4_experts_quant.cu | 106 +++--- .../quantization/fp4/nvfp4_quant_entry.cu | 172 +++++++++ .../quantization/fp4/nvfp4_quant_kernels.cu | 70 ++-- .../quantization/fp4/nvfp4_scaled_mm_entry.cu | 48 +-- .../fp4/nvfp4_scaled_mm_kernels.cu | 153 +++++---- .../fp4/nvfp4_scaled_mm_sm120_kernels.cu | 161 ++++----- .../quantization/fp4/nvfp4_utils.cuh | 2 +- csrc/libtorch_stable/torch_bindings.cpp | 110 ++++++ csrc/libtorch_stable/torch_utils.h | 1 + csrc/ops.h | 44 --- csrc/quantization/fp4/nvfp4_quant_entry.cu | 163 --------- csrc/torch_bindings.cpp | 100 ------ 27 files changed, 1205 insertions(+), 1016 deletions(-) rename csrc/{ => libtorch_stable}/quantization/cutlass_w4a8/get_group_starts.cuh (63%) rename csrc/{ => libtorch_stable}/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu (79%) rename csrc/{ => libtorch_stable}/quantization/cutlass_w4a8/w4a8_mm_entry.cu (83%) rename csrc/{ => libtorch_stable}/quantization/cutlass_w4a8/w4a8_utils.cu (100%) rename csrc/{ => 
libtorch_stable}/quantization/cutlass_w4a8/w4a8_utils.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu (86%) rename csrc/{ => libtorch_stable}/quantization/fp4/nvfp4_blockwise_moe_kernel.cu (66%) rename csrc/{ => libtorch_stable}/quantization/fp4/nvfp4_experts_quant.cu (84%) create mode 100644 csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu rename csrc/{ => libtorch_stable}/quantization/fp4/nvfp4_quant_kernels.cu (80%) rename csrc/{ => libtorch_stable}/quantization/fp4/nvfp4_scaled_mm_entry.cu (55%) rename csrc/{ => libtorch_stable}/quantization/fp4/nvfp4_scaled_mm_kernels.cu (65%) rename csrc/{ => libtorch_stable}/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu (58%) rename csrc/{ => libtorch_stable}/quantization/fp4/nvfp4_utils.cuh (99%) delete mode 100644 csrc/quantization/fp4/nvfp4_quant_entry.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index fc50b896d..3db7ff0bb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -340,8 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/fp4/nvfp4_quant_entry.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp") set_gencode_flags_for_srcs( @@ -489,59 +487,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") " in CUDA target architectures") endif() - # The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require - # CUDA 12.8 or later - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) - set(SRCS - "csrc/quantization/fp4/nvfp4_quant_kernels.cu" - "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" - "csrc/quantization/fp4/nvfp4_experts_quant.cu" - 
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu" - "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${FP4_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1") - message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") - else() - message(STATUS "Not building NVFP4 as no compatible archs were found.") - # clear FP4_ARCHS - set(FP4_ARCHS) - endif() - - # FP4 Archs and flags - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) - set(SRCS - "csrc/quantization/fp4/nvfp4_quant_kernels.cu" - "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" - "csrc/quantization/fp4/nvfp4_experts_quant.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" - "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${FP4_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") - message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") - else() - message(STATUS "Not building NVFP4 as no compatible archs were found.") - # clear FP4_ARCHS - set(FP4_ARCHS) - endif() - # CUTLASS MLA Archs and flags if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") @@ -681,34 +626,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() - # Only build W4A8 kernels if we are building for something compatible with sm90a - cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") - 
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) - set(SRCS - "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu" - "csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu" - "csrc/quantization/cutlass_w4a8/w4a8_utils.cu" - ) - - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${W4A8_ARCHS}") - - list(APPEND VLLM_EXT_SRC "${SRCS}") - - message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 - AND W4A8_ARCHS) - message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " - "not >= 12.0, we recommend upgrading to CUDA 12.0 or " - "later if you intend on running w4a16 quantized models on " - "Hopper.") - else() - message(STATUS "Not building W4A8 kernels as no compatible archs " - "found in CUDA target architectures") - endif() - endif() # Hadacore kernels cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") @@ -760,7 +677,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(VLLM_STABLE_EXT_SRC "csrc/libtorch_stable/torch_bindings.cpp" "csrc/cutlass_extensions/common.cpp" - "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu") + "csrc/cuda_utils_kernels.cu" + "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_STABLE_EXT_SRC @@ -978,6 +898,96 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + # + # FP4/NVFP4 kernels (moved from _C to _C_stable_libtorch) + # + + # The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require + # CUDA 12.8 or later + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} 
VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) + set(SRCS + "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${FP4_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1") + message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") + else() + message(STATUS "Not building NVFP4 as no compatible archs were found.") + # clear FP4_ARCHS + set(FP4_ARCHS) + endif() + + # FP4 Archs and flags + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) + set(SRCS + "csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu" + "csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu" + "csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${FP4_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") + else() + message(STATUS "Not building NVFP4 as no compatible archs were found.") + # clear FP4_ARCHS + set(FP4_ARCHS) + endif() + + # + # W4A8 kernels 
(moved from _C to _C_stable_libtorch) + # + + # Only build W4A8 kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) + set(SRCS + "csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu" + "csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu" + "csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu" + ) + + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${W4A8_ARCHS}") + + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + + message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 + AND W4A8_ARCHS) + message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building W4A8 kernels as no compatible archs " + "found in CUDA target architectures") + endif() + endif() + message(STATUS "Enabling C_stable extension.") define_extension_target( _C_stable_libtorch diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index 5e2f51f93..91e181c58 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -3,8 +3,8 @@ #pragma once -#include -#include +#include +#include #include #ifdef USE_ROCM diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh index f61fe3ceb..116ce854d 100644 --- a/csrc/cutlass_extensions/cute_utils.cuh +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -1,7 +1,6 @@ #pragma once #include -#include namespace cute { //////////////////////////////////////////////////////////////////// diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp index 
5c1d6e3f4..8aa99b3e0 100644 --- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -189,9 +189,9 @@ struct Sm90RowOrScalarBroadcastArray { } auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); - Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); - Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); for (int i = 0; i < size(tGS_gRow_flt); ++i) { if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { @@ -211,8 +211,8 @@ struct Sm90RowOrScalarBroadcastArray { begin_loop(int epi_m, int epi_n) { if (epi_m == 0) { // Assumes M-major subtile loop if (!params.row_broadcast) return; // Do not issue LDS when row is scalar - Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); - Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); copy(tSR_sRow_flt, tSR_rRow_flt); } } @@ -241,9 +241,9 @@ struct Sm90RowOrScalarBroadcastArray { auto [m, n, k, l] = args.tile_coord_mnkl; using ThreadCount = decltype(size(args.tiled_copy)); - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); - Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem), + cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); + cute::Tensor gRow = local_tile(mRow(_,_,l), 
take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + cute::Tensor sRow = make_tensor(make_smem_ptr(smem), make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) //// G2S: Gmem to Smem auto tiled_g2s = make_tiled_copy(Copy_Atom{}, @@ -251,16 +251,16 @@ struct Sm90RowOrScalarBroadcastArray { Stride<_0, _1>>{}, Layout<_1>{}); auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); - Tensor tGS_gRow = thr_g2s.partition_S(gRow); - Tensor tGS_sRow = thr_g2s.partition_D(sRow); + cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow); + cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow); //// G2S: Coord auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); - Tensor tGS_cRow = thr_g2s.partition_S(cRow); + cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow); //// S2R: Smem to Reg - Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + cute::Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) return ConsumerStoreCallbacks( tGS_gRow, @@ -389,7 +389,7 @@ struct Sm90ColOrScalarBroadcastArray { CUTLASS_DEVICE void begin() { - Tensor pred = make_tensor(shape(tCgCol)); + cute::Tensor pred = make_tensor(shape(tCgCol)); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(pred); ++i) { pred(i) = get<0>(tCcCol(i)) < m; @@ -409,7 +409,7 @@ struct Sm90ColOrScalarBroadcastArray { CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { Array frg_col; - Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { @@ -431,16 +431,16 @@ struct Sm90ColOrScalarBroadcastArray { auto [M, N, K, L] = 
args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); - Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); + cute::Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) // Generate an identity tensor matching the shape of the global tensor and // partition the same way, this will be used to generate the predicate // tensor for loading - Tensor cCol = make_identity_tensor(mCol.shape()); - Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor cCol = make_identity_tensor(mCol.shape()); + cute::Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); return ConsumerStoreCallbacks( diff --git a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp index 58b1e8ff1..8203d8930 100644 --- a/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -186,9 +186,9 @@ struct Sm90RowOrScalarBroadcast { } auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); - Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); - Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + cute::Tensor 
tGS_gRow_flt = filter_zeros(tGS_gRow); + cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); for (int i = 0; i < size(tGS_gRow_flt); ++i) { if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { @@ -208,8 +208,8 @@ struct Sm90RowOrScalarBroadcast { begin_loop(int epi_m, int epi_n) { if (epi_m == 0) { // Assumes M-major subtile loop if (!params.row_broadcast) return; // Do not issue LDS when row is scalar - Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); - Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); copy(tSR_sRow_flt, tSR_rRow_flt); } } @@ -238,9 +238,9 @@ struct Sm90RowOrScalarBroadcast { auto [m, n, k, l] = args.tile_coord_mnkl; using ThreadCount = decltype(size(args.tiled_copy)); - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); - Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem), + cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + cute::Tensor sRow = make_tensor(make_smem_ptr(smem), make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) //// G2S: Gmem to Smem auto tiled_g2s = make_tiled_copy(Copy_Atom{}, @@ -248,16 +248,16 @@ struct Sm90RowOrScalarBroadcast { Stride<_0, _1>>{}, Layout<_1>{}); auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); - Tensor tGS_gRow = thr_g2s.partition_S(gRow); - Tensor tGS_sRow = thr_g2s.partition_D(sRow); + cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow); + cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow); 
//// G2S: Coord auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); - Tensor tGS_cRow = thr_g2s.partition_S(cRow); + cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow); //// S2R: Smem to Reg - Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + cute::Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) return ConsumerStoreCallbacks( tGS_gRow, @@ -382,7 +382,7 @@ struct Sm90ColOrScalarBroadcast { CUTLASS_DEVICE void begin() { - Tensor pred = make_tensor(shape(tCgCol)); + cute::Tensor pred = make_tensor(shape(tCgCol)); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(pred); ++i) { pred(i) = get<0>(tCcCol(i)) < m; @@ -402,7 +402,7 @@ struct Sm90ColOrScalarBroadcast { CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { Array frg_col; - Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { @@ -422,16 +422,16 @@ struct Sm90ColOrScalarBroadcast { get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto [M, N, K, L] = args.problem_shape_mnkl; - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); - Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + cute::Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor tCrCol = make_tensor_like(tCgCol); // 
(CPY,CPY_M,CPY_N,EPI_M,EPI_N) // Generate an identity tensor matching the shape of the global tensor and // partition the same way, this will be used to generate the predicate // tensor for loading - Tensor cCol = make_identity_tensor(mCol.shape()); - Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor cCol = make_identity_tensor(mCol.shape()); + cute::Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); return ConsumerStoreCallbacks( diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp index a1ff933cc..45f25ea3d 100644 --- a/csrc/cutlass_extensions/torch_utils.hpp +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -1,6 +1,21 @@ #pragma once -#include +// This header is shared between _C (unstable ABI, used by machete) and +// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION +// is defined only for the stable target, so we switch includes and types +// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor. +#ifdef TORCH_TARGET_VERSION + #include + #include + #include + #include // for STD_TORCH_CHECK +using TorchTensor = torch::stable::Tensor; + #define TORCH_UTILS_CHECK STD_TORCH_CHECK +#else + #include +using TorchTensor = torch::Tensor; + #define TORCH_UTILS_CHECK TORCH_CHECK +#endif #include "cute/layout.hpp" #include "cutlass/layout/matrix.h" @@ -55,35 +70,35 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) { // If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra // strides are set to be 0 or 1. 
template -static inline auto make_cute_layout(torch::Tensor const& tensor, +static inline auto make_cute_layout(TorchTensor const& tensor, std::string_view name = "tensor") { - TORCH_CHECK(tensor.dim() <= rank(Stride{})); - auto stride = cute::transform_with_idx( - Stride{}, [&](auto const& stride_ele, auto const& idx) { - using StrideEle = std::decay_t; + TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{})); + auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele, + auto const& idx) { + using StrideEle = std::decay_t; - if (idx < tensor.dim()) { - if constexpr (cute::is_static_v) { - TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", - name, ".stride(", idx, ") to be ", StrideEle::value); - return StrideEle{}; - } else { - if (tensor.size(idx) == 1) { - // use 0 stride for dim with size 1, this is easier for - // cute/cutlass to optimize (helps the TMA code flatten dims) - return StrideEle{0}; - } else { - return tensor.stride(idx); - } - } + if (idx < tensor.dim()) { + if constexpr (cute::is_static_v) { + TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", + name, ".stride(", idx, ") to be ", StrideEle::value); + return StrideEle{}; + } else { + if (tensor.size(idx) == 1) { + // use 0 stride for dim with size 1, this is easier for + // cute/cutlass to optimize (helps the TMA code flatten dims) + return StrideEle{0}; } else { - // Extra strides are assumed to be 0 or 1 - if constexpr (cute::is_static_v) { - static_assert(StrideEle::value == 0 || StrideEle::value == 1); - } - return StrideEle{}; + return tensor.stride(idx); } - }); + } + } else { + // Extra strides are assumed to be 0 or 1 + if constexpr (cute::is_static_v) { + static_assert(StrideEle::value == 0 || StrideEle::value == 1); + } + return StrideEle{}; + } + }); auto shape = cute::make_shape_from_idx([&](auto const& idx) { if (idx < tensor.dim()) @@ -97,7 +112,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor, template static inline 
auto maybe_make_cute_layout( - std::optional const& tensor, + std::optional const& tensor, std::string_view name = "tensor") { using Layout = decltype(make_cute_layout(*tensor)); @@ -121,12 +136,12 @@ template using equivalent_cutlass_type_t = typename equivalent_cutlass_type::type; template <> -struct equivalent_cutlass_type { +struct equivalent_cutlass_type { using type = cutlass::half_t; }; template <> -struct equivalent_cutlass_type { +struct equivalent_cutlass_type { using type = cutlass::bfloat16_t; }; @@ -134,8 +149,8 @@ struct equivalent_cutlass_type { // equivalent_scalar_t (basically inverse of equivalent_cutlass_type) // -// Return a `c10::CppTypeToScalarType` compatible type, i.e. get the C++ from -// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` +// Return a `torch::headeronly::CppTypeToScalarType` compatible type, i.e. +// get the C++ type equivalent to T, e.g.: `cutlass::half_t -> Half` template struct equivalent_scalar_type { using type = T; @@ -146,15 +161,15 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type::type; template <> struct equivalent_scalar_type { - using type = c10::Half; + using type = torch::headeronly::Half; }; template <> struct equivalent_scalar_type { - using type = c10::BFloat16; + using type = torch::headeronly::BFloat16; }; -// get equivalent c10::ScalarType tag from compile time type +// get equivalent torch::headeronly::ScalarType tag from compile time type template -static inline constexpr c10::ScalarType equivalent_scalar_type_v = - c10::CppTypeToScalarType>::value; \ No newline at end of file +static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v = + torch::headeronly::CppTypeToScalarType>::value; diff --git a/csrc/libtorch_stable/dispatch_utils.h b/csrc/libtorch_stable/dispatch_utils.h index 5ebba72b1..ffc2ca031 100644 --- a/csrc/libtorch_stable/dispatch_utils.h +++ b/csrc/libtorch_stable/dispatch_utils.h @@ -49,6 +49,15 @@ THO_DISPATCH_SWITCH(TYPE, NAME, \ 
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) +// Half types dispatch (Half + BFloat16) +#define VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(...) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_STABLE_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \ + THO_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__)) + // Boolean dispatch #define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \ if (expr) { \ diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h index 914f033b8..8153102c5 100644 --- a/csrc/libtorch_stable/ops.h +++ b/csrc/libtorch_stable/ops.h @@ -84,4 +84,54 @@ void get_cutlass_batched_moe_mm_data( const torch::stable::Tensor& expert_num_tokens, const int64_t num_local_experts, const int64_t padded_m, const int64_t n, const int64_t k); + +// FP4/NVFP4 ops +bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability); + +void cutlass_scaled_fp4_mm(torch::stable::Tensor& D, + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, + torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha); + +void cutlass_fp4_group_mm(torch::stable::Tensor& output, + const torch::stable::Tensor& a, + const torch::stable::Tensor& b, + const torch::stable::Tensor& a_blockscale, + const torch::stable::Tensor& b_blockscales, + const torch::stable::Tensor& alphas, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& sf_offsets); + +std::tuple scaled_fp4_quant_func( + torch::stable::Tensor const& input, + torch::stable::Tensor const& input_scale, bool is_sf_swizzled_layout); + +void scaled_fp4_quant_out(torch::stable::Tensor const& input, + torch::stable::Tensor const& input_scale, + bool is_sf_swizzled_layout, + torch::stable::Tensor& output, + torch::stable::Tensor& output_scale); + +void 
scaled_fp4_experts_quant( + torch::stable::Tensor& output, torch::stable::Tensor& output_scale, + torch::stable::Tensor const& input, + torch::stable::Tensor const& input_global_scale, + torch::stable::Tensor const& input_offset_by_experts, + torch::stable::Tensor const& output_scale_offset_by_experts); + +void silu_and_mul_scaled_fp4_experts_quant( + torch::stable::Tensor& output, torch::stable::Tensor& output_scale, + torch::stable::Tensor const& input, + torch::stable::Tensor const& input_global_scale, + torch::stable::Tensor const& input_offset_by_experts, + torch::stable::Tensor const& output_scale_offset_by_experts); + +void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out, + torch::stable::Tensor& output_block_scale, + torch::stable::Tensor& input, + torch::stable::Tensor& input_global_scale); + #endif diff --git a/csrc/quantization/cutlass_w4a8/get_group_starts.cuh b/csrc/libtorch_stable/quantization/cutlass_w4a8/get_group_starts.cuh similarity index 63% rename from csrc/quantization/cutlass_w4a8/get_group_starts.cuh rename to csrc/libtorch_stable/quantization/cutlass_w4a8/get_group_starts.cuh index fec142d0d..5cda4c975 100644 --- a/csrc/quantization/cutlass_w4a8/get_group_starts.cuh +++ b/csrc/libtorch_stable/quantization/cutlass_w4a8/get_group_starts.cuh @@ -2,10 +2,9 @@ #pragma once #include -#include -#include +#include +#include "libtorch_stable/torch_utils.h" -#include "core/scalar_type.hpp" #include "cutlass/bfloat16.h" #include "cutlass/float8.h" @@ -41,7 +40,7 @@ __global__ void get_group_gemm_starts( } #define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ - else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \ get_group_gemm_starts> \ <<<1, num_experts, 0, stream>>>( \ @@ -66,23 +65,34 @@ __global__ void get_group_gemm_starts( namespace { void run_get_group_gemm_starts( - torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, - torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, - 
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, - torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors, - torch::Tensor const& b_tensors, torch::Tensor& out_tensors, - torch::Tensor const& a_scales, torch::Tensor const& b_scales, - torch::Tensor const& b_group_scales, const int64_t b_group_size) { - TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32 - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_group_scales.dtype() == - torch::kFloat8_e4m3fn); // the underlying torch type is e4m3 - TORCH_CHECK(out_tensors.dtype() == - torch::kBFloat16); // only support bf16 for now + torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs, + torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs, + torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs, + torch::stable::Tensor& b_group_scales_ptrs, + torch::stable::Tensor const& a_tensors, + torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + torch::stable::Tensor const& b_group_scales, const int64_t b_group_size) { + STD_TORCH_CHECK(a_tensors.scalar_type() == + torch::headeronly::ScalarType::Float8_e4m3fn); + STD_TORCH_CHECK( + b_tensors.scalar_type() == + torch::headeronly::ScalarType::Int); // int4 8x packed into int32 + STD_TORCH_CHECK(a_scales.scalar_type() == + torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(b_scales.scalar_type() == + torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK( + b_group_scales.scalar_type() == + torch::headeronly::ScalarType::Float8_e4m3fn); // the underlying torch + // type is e4m3 + STD_TORCH_CHECK( + out_tensors.scalar_type() == + torch::headeronly::ScalarType::BFloat16); // only support bf16 for now // expect int64_t to avoid overflow during 
offset calculations - TORCH_CHECK(expert_offsets.dtype() == torch::kInt64); + STD_TORCH_CHECK(expert_offsets.scalar_type() == + torch::headeronly::ScalarType::Long); int num_experts = static_cast(expert_offsets.size(0)); // logical k, n @@ -90,15 +100,16 @@ void run_get_group_gemm_starts( int64_t k = a_tensors.size(1); int64_t scale_k = cutlass::ceil_div(k, b_group_size); - auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + auto stream = get_current_cuda_stream(a_tensors.get_device_index()); if (false) { } - __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) - __CALL_GET_STARTS_KERNEL(torch::kFloat16, half) + __CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16, + cutlass::bfloat16_t) + __CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half) else { - TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); } } -} // namespace \ No newline at end of file +} // namespace diff --git a/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu similarity index 79% rename from csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu rename to csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu index 4b425790d..1091d9d12 100644 --- a/csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu +++ b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu @@ -14,13 +14,12 @@ #include "cutlass/util/mixed_dtype_utils.hpp" // vllm includes -#include -#include -#include +#include +#include +#include "libtorch_stable/torch_utils.h" #include "cutlass_extensions/torch_utils.hpp" #include "cutlass_extensions/common.hpp" -#include "core/registration.h" #include "get_group_starts.cuh" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "w4a8_utils.cuh" @@ -168,31 +167,40 @@ struct W4A8GroupedGemmKernel { 
static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0, "LayoutB_Reordered size must be divisible by 4 bytes"); - static void grouped_mm( - torch::Tensor& out_tensors, const torch::Tensor& a_tensors, - const torch::Tensor& b_tensors, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, - const int64_t b_group_size, const torch::Tensor& expert_offsets, - const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides, - const torch::Tensor& b_strides, const torch::Tensor& c_strides, - const torch::Tensor& group_scale_strides) { + static void grouped_mm(torch::stable::Tensor& out_tensors, + const torch::stable::Tensor& a_tensors, + const torch::stable::Tensor& b_tensors, + const torch::stable::Tensor& a_scales, + const torch::stable::Tensor& b_scales, + const torch::stable::Tensor& b_group_scales, + const int64_t b_group_size, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& problem_sizes_torch, + const torch::stable::Tensor& a_strides, + const torch::stable::Tensor& b_strides, + const torch::stable::Tensor& c_strides, + const torch::stable::Tensor& group_scale_strides) { auto device = a_tensors.device(); auto device_id = device.index(); - const at::cuda::OptionalCUDAGuard device_guard(device); - auto stream = at::cuda::getCurrentCUDAStream(device_id); + const torch::stable::accelerator::DeviceGuard device_guard(device_id); + auto stream = get_current_cuda_stream(device_id); int num_experts = static_cast(expert_offsets.size(0)); int n = static_cast(b_tensors.size(1)); int k = static_cast(b_tensors.size(2)) * PackFactor; - auto options_int = - torch::TensorOptions().dtype(torch::kInt64).device(device); - torch::Tensor a_ptrs = torch::empty(num_experts, options_int); - torch::Tensor b_ptrs = torch::empty(num_experts, options_int); - torch::Tensor out_ptrs = torch::empty(num_experts, options_int); - torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); - 
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); - torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int); + torch::stable::Tensor a_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor b_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor out_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor a_scales_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor b_scales_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); + torch::stable::Tensor b_group_scales_ptrs = torch::stable::empty( + num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device); // get the correct offsets to pass to gemm run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, @@ -247,9 +255,9 @@ struct W4A8GroupedGemmKernel { // Allocate workspace size_t workspace_size = GemmShuffled::get_workspace_size(arguments); - torch::Tensor workspace = - torch::empty(workspace_size, - torch::TensorOptions().dtype(torch::kU8).device(device)); + torch::stable::Tensor workspace = torch::stable::empty( + workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt, + device); // Run GEMM GemmShuffled gemm; @@ -294,14 +302,20 @@ using Kernel_256x128_2x1x1_Coop = using Kernel_128x256_2x1x1_Coop = W4A8GroupedGemmKernel, Shape<_2, _1, _1>, Coop, CoopEpi>; -void mm_dispatch( - torch::Tensor& out_tensors, const torch::Tensor& a_tensors, - const torch::Tensor& b_tensors, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, - const int64_t b_group_size, const torch::Tensor& expert_offsets, - const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, 
- const torch::Tensor& b_strides, const torch::Tensor& c_strides, - const torch::Tensor& group_scale_strides, const std::string& schedule) { +void mm_dispatch(torch::stable::Tensor& out_tensors, + const torch::stable::Tensor& a_tensors, + const torch::stable::Tensor& b_tensors, + const torch::stable::Tensor& a_scales, + const torch::stable::Tensor& b_scales, + const torch::stable::Tensor& b_group_scales, + const int64_t b_group_size, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& a_strides, + const torch::stable::Tensor& b_strides, + const torch::stable::Tensor& c_strides, + const torch::stable::Tensor& group_scale_strides, + const std::string& schedule) { if (schedule == "Kernel_128x16_1x1x1_Coop") { Kernel_128x16_1x1x1_Coop::grouped_mm( out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, @@ -358,18 +372,23 @@ void mm_dispatch( b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, group_scale_strides); } else { - TORCH_CHECK(false, - "cutlass_w4a8_moe_mm: unknown schedule string: ", schedule); + STD_TORCH_CHECK(false, + "cutlass_w4a8_moe_mm: unknown schedule string: ", schedule); } } -void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors, - const torch::Tensor& b_tensors, const torch::Tensor& a_scales, - const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, - const int64_t b_group_size, const torch::Tensor& expert_offsets, - const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, - const torch::Tensor& b_strides, const torch::Tensor& c_strides, - const torch::Tensor& group_scale_strides, +void mm(torch::stable::Tensor& out_tensors, + const torch::stable::Tensor& a_tensors, + const torch::stable::Tensor& b_tensors, + const torch::stable::Tensor& a_scales, + const torch::stable::Tensor& b_scales, + const torch::stable::Tensor& b_group_scales, const int64_t b_group_size, + const torch::stable::Tensor& 
expert_offsets, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& a_strides, + const torch::stable::Tensor& b_strides, + const torch::stable::Tensor& c_strides, + const torch::stable::Tensor& group_scale_strides, std::optional maybe_schedule) { // user has specified a schedule if (maybe_schedule) { @@ -406,26 +425,27 @@ void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors, a_strides, b_strides, c_strides, group_scale_strides, schedule); } -std::tuple encode_and_reorder_int4b( - torch::Tensor const& b_tensors) { - TORCH_CHECK(b_tensors.dtype() == torch::kInt32); - TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k) - TORCH_CHECK(b_tensors.is_contiguous()); - TORCH_CHECK(b_tensors.is_cuda()); +std::tuple +encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) { + STD_TORCH_CHECK(b_tensors.scalar_type() == + torch::headeronly::ScalarType::Int); + STD_TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k) + STD_TORCH_CHECK(b_tensors.is_contiguous()); + STD_TORCH_CHECK(b_tensors.is_cuda()); int n = static_cast(b_tensors.size(1)); int k = static_cast(b_tensors.size(2)) * PackFactor; // logical k // CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0. // These misalignments cause silent OOB unless run under Compute Sanitizer. 
- TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256"); - TORCH_CHECK(n % 16 == 0, "n must be divisible by 16"); + STD_TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256"); + STD_TORCH_CHECK(n % 16 == 0, "n must be divisible by 16"); // we will store the layout to an int32 tensor; // this is the number of elements we need per layout constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t); - torch::Tensor b_tensors_packed = torch::empty_like(b_tensors); + torch::stable::Tensor b_tensors_packed = torch::stable::empty_like(b_tensors); int num_experts = static_cast(b_tensors.size(0)); auto b_ptr = static_cast(b_tensors.const_data_ptr()); @@ -435,7 +455,7 @@ std::tuple encode_and_reorder_int4b( size_t num_int4_elems = 1ull * num_experts * n * k; bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr, num_int4_elems); - TORCH_CHECK(ok, "unified_encode_int4b failed"); + STD_TORCH_CHECK(ok, "unified_encode_int4b failed"); // construct the layout once; assumes each expert has the same layout using LayoutType = LayoutB_Reordered; @@ -456,28 +476,28 @@ std::tuple encode_and_reorder_int4b( } // save the packed layout to torch tensor so we can re-use it - auto cpu_opts = - torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); - torch::Tensor layout_cpu = - torch::empty({num_experts, layout_width}, cpu_opts); + torch::stable::Tensor layout_cpu = torch::stable::empty( + {num_experts, layout_width}, torch::headeronly::ScalarType::Int, + std::nullopt, torch::stable::Device(torch::stable::DeviceType::CPU)); - int32_t* layout_data = layout_cpu.data_ptr(); + int32_t* layout_data = layout_cpu.mutable_data_ptr(); for (int i = 0; i < num_experts; ++i) { std::memcpy(layout_data + i * layout_width, // dst (int32*) &layout_B_reordered, // src (LayoutType*) sizeof(LayoutType)); // number of bytes } - torch::Tensor packed_layout = - layout_cpu.to(b_tensors.device(), /*non_blocking=*/false); + torch::stable::Tensor 
packed_layout = + torch::stable::to(layout_cpu, b_tensors.device(), + /*non_blocking=*/false); return {b_tensors_packed, packed_layout}; } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("cutlass_w4a8_moe_mm", &mm); - m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b); +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("cutlass_w4a8_moe_mm", TORCH_BOX(&mm)); + m.impl("cutlass_encode_and_reorder_int4b_grouped", + TORCH_BOX(&encode_and_reorder_int4b)); } } // namespace vllm::cutlass_w4a8_moe -///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu similarity index 83% rename from csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu rename to csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu index f77af06cd..c2b8c0c00 100644 --- a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu +++ b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -3,14 +3,12 @@ // https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu // -#include -#include -#include +#include +#include +#include "libtorch_stable/torch_utils.h" #include "cutlass_extensions/torch_utils.hpp" #include "w4a8_utils.cuh" -#include "core/registration.h" - #include "cutlass/cutlass.h" #include @@ -161,31 +159,31 @@ struct W4A8GemmKernel { using StrideD = typename GemmKernelShuffled::StrideD; using StrideS = typename CollectiveMainloopShuffled::StrideScale; - static torch::Tensor mm(torch::Tensor const& A, - torch::Tensor const& B, // already packed - torch::Tensor const& group_scales, // already packed - int64_t group_size, - torch::Tensor const& channel_scales, - torch::Tensor const& token_scales, - std::optional const& maybe_out_type) { + static torch::stable::Tensor mm( + torch::stable::Tensor const& A, 
+ torch::stable::Tensor const& B, // already packed + torch::stable::Tensor const& group_scales, // already packed + int64_t group_size, torch::stable::Tensor const& channel_scales, + torch::stable::Tensor const& token_scales, + std::optional const& maybe_out_type) { // TODO: param validation int m = A.size(0); int k = A.size(1); int n = B.size(1); // safely cast group_size to int - TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits::max(), - "group_size out of supported range for int: ", group_size); + STD_TORCH_CHECK( + group_size > 0 && group_size <= std::numeric_limits::max(), + "group_size out of supported range for int: ", group_size); int const group_size_int = static_cast(group_size); // Allocate output - const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + const torch::stable::accelerator::DeviceGuard device_guard( + A.get_device_index()); auto device = A.device(); - auto stream = at::cuda::getCurrentCUDAStream(device.index()); - torch::Tensor D = - torch::empty({m, n}, torch::TensorOptions() - .dtype(equivalent_scalar_type_v) - .device(device)); + auto stream = get_current_cuda_stream(device.index()); + torch::stable::Tensor D = torch::stable::empty( + {m, n}, equivalent_scalar_type_v, std::nullopt, device); // prepare arg pointers auto A_ptr = static_cast(A.const_data_ptr()); auto B_ptr = static_cast(B.const_data_ptr()); @@ -237,9 +235,9 @@ struct W4A8GemmKernel { // Workspace size_t workspace_size = GemmShuffled::get_workspace_size(arguments); - torch::Tensor workspace = - torch::empty(workspace_size, - torch::TensorOptions().dtype(torch::kU8).device(device)); + torch::stable::Tensor workspace = torch::stable::empty( + workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt, + device); // Run GEMM GemmShuffled gemm; @@ -269,14 +267,14 @@ using Kernel_128x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; using Kernel_128x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; using Kernel_128x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, 
_1>>; -torch::Tensor mm_dispatch(torch::Tensor const& A, - torch::Tensor const& B, // already packed - torch::Tensor const& group_scales, // already packed - int64_t group_size, - torch::Tensor const& channel_scales, - torch::Tensor const& token_scales, - std::optional const& maybe_out_type, - const std::string& schedule) { +torch::stable::Tensor mm_dispatch( + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, // already packed + torch::stable::Tensor const& group_scales, // already packed + int64_t group_size, torch::stable::Tensor const& channel_scales, + torch::stable::Tensor const& token_scales, + std::optional const& maybe_out_type, + const std::string& schedule) { if (schedule == "256x128_1x1x1") { return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size, channel_scales, token_scales, @@ -318,17 +316,18 @@ torch::Tensor mm_dispatch(torch::Tensor const& A, channel_scales, token_scales, maybe_out_type); } - TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule); + STD_TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule); return {}; } -torch::Tensor mm(torch::Tensor const& A, - torch::Tensor const& B, // already packed - torch::Tensor const& group_scales, // already packed - int64_t group_size, torch::Tensor const& channel_scales, - torch::Tensor const& token_scales, - std::optional const& maybe_out_type, - std::optional maybe_schedule) { +torch::stable::Tensor mm( + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, // already packed + torch::stable::Tensor const& group_scales, // already packed + int64_t group_size, torch::stable::Tensor const& channel_scales, + torch::stable::Tensor const& token_scales, + std::optional const& maybe_out_type, + std::optional maybe_schedule) { // requested a specific schedule if (maybe_schedule) { return mm_dispatch(A, B, group_scales, group_size, channel_scales, @@ -378,14 +377,15 @@ torch::Tensor mm(torch::Tensor const& A, // 
---------------------------------------------------------------------------- // Pre-processing utils // ---------------------------------------------------------------------------- -torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { - TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(scales.is_contiguous()); - TORCH_CHECK(scales.is_cuda()); +torch::stable::Tensor pack_scale_fp8(torch::stable::Tensor const& scales) { + STD_TORCH_CHECK(scales.scalar_type() == + torch::headeronly::ScalarType::Float8_e4m3fn); + STD_TORCH_CHECK(scales.is_contiguous()); + STD_TORCH_CHECK(scales.is_cuda()); - auto packed_scales = torch::empty( - {scales.numel() * ScalePackSize}, - torch::TensorOptions().dtype(scales.dtype()).device(scales.device())); + auto packed_scales = + torch::stable::empty({scales.numel() * ScalePackSize}, + scales.scalar_type(), std::nullopt, scales.device()); auto scales_ptr = static_cast(scales.const_data_ptr()); auto packed_scales_ptr = static_cast*>( @@ -396,15 +396,16 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { return packed_scales; } -torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { - TORCH_CHECK(B.dtype() == torch::kInt32); - TORCH_CHECK(B.dim() == 2); +torch::stable::Tensor encode_and_reorder_int4b(torch::stable::Tensor const& B) { + STD_TORCH_CHECK(B.scalar_type() == torch::headeronly::ScalarType::Int); + STD_TORCH_CHECK(B.dim() == 2); - torch::Tensor B_packed = torch::empty_like(B); + torch::stable::Tensor B_packed = torch::stable::empty_like(B); int k = B.size(0) * PackFactor; // logical k int n = B.size(1); - TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks"); + STD_TORCH_CHECK((n * k) % 32 == 0, + "need multiples of 32 int4s for 16B chunks"); auto B_ptr = static_cast(B.const_data_ptr()); auto B_packed_ptr = static_cast(B_packed.data_ptr()); @@ -415,16 +416,17 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { bool ok = 
vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); - TORCH_CHECK(ok, "unified_encode_int4b failed"); + STD_TORCH_CHECK(ok, "unified_encode_int4b failed"); cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); return B_packed; } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("cutlass_w4a8_mm", &mm); - m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8); - m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b); +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("cutlass_w4a8_mm", TORCH_BOX(&mm)); + m.impl("cutlass_pack_scale_fp8", TORCH_BOX(&pack_scale_fp8)); + m.impl("cutlass_encode_and_reorder_int4b", + TORCH_BOX(&encode_and_reorder_int4b)); } -} // namespace vllm::cutlass_w4a8 \ No newline at end of file +} // namespace vllm::cutlass_w4a8 diff --git a/csrc/quantization/cutlass_w4a8/w4a8_utils.cu b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu similarity index 100% rename from csrc/quantization/cutlass_w4a8/w4a8_utils.cu rename to csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu diff --git a/csrc/quantization/cutlass_w4a8/w4a8_utils.cuh b/csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cuh similarity index 100% rename from csrc/quantization/cutlass_w4a8/w4a8_utils.cuh rename to csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cuh diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu similarity index 86% rename from csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu rename to csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index 3539096c9..49f2944f3 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -14,16 +14,15 @@ * limitations under the License. 
*/ -#include +#include +#include "libtorch_stable/torch_utils.h" +#include "libtorch_stable/dispatch_utils.h" +#include "cuda_vec_utils.cuh" #include #include -#include -#include - #include -#include "dispatch_utils.h" #include "cuda_utils.h" #include "launch_bounds_utils.h" @@ -118,17 +117,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) } // namespace vllm -void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] - torch::Tensor& output_sf, - torch::Tensor& input, // [..., 2 * d] - torch::Tensor& input_sf) { +void silu_and_mul_nvfp4_quant_sm1xxa( + torch::stable::Tensor& output, // [..., d] + torch::stable::Tensor& output_sf, + torch::stable::Tensor& input, // [..., 2 * d] + torch::stable::Tensor& input_sf) { int32_t m = input.size(0); int32_t n = input.size(1) / 2; - TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); - TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || - input.scalar_type() == at::ScalarType::BFloat16, - "Unsupported input data type for quantize_to_fp4."); + STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + STD_TORCH_CHECK( + input.scalar_type() == torch::headeronly::ScalarType::Half || + input.scalar_type() == torch::headeronly::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); int multiProcessorCount = get_device_attribute(cudaDevAttrMultiProcessorCount, -1); @@ -136,8 +137,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] auto input_sf_ptr = static_cast(input_sf.data_ptr()); auto sf_out = static_cast(output_sf.data_ptr()); auto output_ptr = static_cast(output.data_ptr()); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + auto stream = get_current_cuda_stream(input.get_device_index()); dim3 block(std::min(int(n / 
ELTS_PER_THREAD), 512)); int const numBlocksPerSM = vllm_runtime_blocks_per_sm(static_cast(block.x)); @@ -149,7 +151,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); dim3 grid(grid_x, grid_y); - VLLM_DISPATCH_HALF_TYPES( + VLLM_STABLE_DISPATCH_HALF_TYPES( input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { using cuda_type = vllm::CUDATypeConverter::Type; auto input_ptr = static_cast(input.data_ptr()); diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu similarity index 66% rename from csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu rename to csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu index ae8ef1bf9..b22308d25 100644 --- a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -14,14 +14,12 @@ * limitations under the License. 
*/ -#include "core/registration.h" +#include +#include +#include "libtorch_stable/torch_utils.h" -#include #include -#include -#include -#include #include "cutlass_extensions/common.hpp" #include "cute/tensor.hpp" @@ -122,7 +120,7 @@ __global__ void __get_group_gemm_starts( #define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \ TENSOR_C_TYPE, C_TYPE, LayoutSFA, \ LayoutSFB, ScaleConfig) \ - else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \ __get_group_gemm_starts \ <<<1, num_experts, 0, stream>>>( \ @@ -150,50 +148,64 @@ __global__ void __get_group_gemm_starts( } template -void run_get_group_gemm_starts( - const torch::Tensor& a_starts, const torch::Tensor& b_starts, - const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts, - const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts, - const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, - const torch::Tensor& a_strides, const torch::Tensor& b_strides, - const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val, - int64_t c_stride_val, - /*these are used for their base addresses*/ - torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, - torch::Tensor const& out_tensors, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, torch::Tensor const& alphas, - torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets, - torch::Tensor const& problem_sizes, int M, int N, int K) { +void run_get_group_gemm_starts(const torch::stable::Tensor& a_starts, + const torch::stable::Tensor& b_starts, + const torch::stable::Tensor& out_starts, + const torch::stable::Tensor& a_scales_starts, + const torch::stable::Tensor& b_scales_starts, + const torch::stable::Tensor& alpha_starts, + const torch::stable::Tensor& layout_sfa, + const torch::stable::Tensor& layout_sfb, + const torch::stable::Tensor& a_strides, + const torch::stable::Tensor& b_strides, + const 
torch::stable::Tensor& c_strides, + int64_t a_stride_val, int64_t b_stride_val, + int64_t c_stride_val, + /*these are used for their base addresses*/ + torch::stable::Tensor const& a_tensors, + torch::stable::Tensor const& b_tensors, + torch::stable::Tensor const& out_tensors, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + torch::stable::Tensor const& alphas, + torch::stable::Tensor const& expert_offsets, + torch::stable::Tensor const& sf_offsets, + torch::stable::Tensor const& problem_sizes, + int M, int N, int K) { int num_experts = (int)expert_offsets.size(0); - auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + auto stream = get_current_cuda_stream(a_tensors.get_device_index()); - TORCH_CHECK(out_tensors.size(1) == N, - "Output tensor shape doesn't match expected shape"); - TORCH_CHECK(K / 2 == b_tensors.size(2), - "b_tensors(dim = 2) and a_tensors(dim = 1) trailing" - " dimension must match"); + STD_TORCH_CHECK(out_tensors.size(1) == N, + "Output tensor shape doesn't match expected shape"); + STD_TORCH_CHECK(K / 2 == b_tensors.size(2), + "b_tensors(dim = 2) and a_tensors(dim = 1) trailing" + " dimension must match"); if (false) { } //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, // ScaleConfig) __CALL_GET_STARTS_KERNEL_BLOCKSCALE( - cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16, - cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig) + cutlass::float_e2m1_t, cutlass::float_ue4m3_t, + torch::headeronly::ScalarType::BFloat16, cutlass::bfloat16_t, LayoutSFA, + LayoutSFB, ScaleConfig) __CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t, - cutlass::float_ue4m3_t, torch::kFloat16, - half, LayoutSFA, LayoutSFB, ScaleConfig) + cutlass::float_ue4m3_t, + torch::headeronly::ScalarType::Half, half, + LayoutSFA, LayoutSFB, ScaleConfig) else { - TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + STD_TORCH_CHECK(false, "Invalid output type 
(must be float16 or bfloat16)"); } } template void run_fp4_blockwise_scaled_group_mm_sm100( - torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, - const torch::Tensor& alphas, const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, - int N, int K) { + torch::stable::Tensor& output, const torch::stable::Tensor& a, + const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale, + const torch::stable::Tensor& b_blockscales, + const torch::stable::Tensor& alphas, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& sf_offsets, int M, int N, int K) { using ProblemShape = cutlass::gemm::GroupProblemShape>; using ElementType = cutlass::float_e2m1_t; @@ -272,20 +284,40 @@ void run_fp4_blockwise_scaled_group_mm_sm100( using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; int num_experts = static_cast(expert_offsets.size(0)); - auto options_int = - torch::TensorOptions().dtype(torch::kInt64).device(a.device()); - torch::Tensor a_ptrs = torch::empty(num_experts, options_int); - torch::Tensor b_ptrs = torch::empty(num_experts, options_int); - torch::Tensor out_ptrs = torch::empty(num_experts, options_int); - torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); - torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); - torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); - torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); - torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); - torch::Tensor a_strides1 = torch::empty(num_experts, options_int); - torch::Tensor b_strides1 = torch::empty(num_experts, options_int); - torch::Tensor c_strides1 = torch::empty(num_experts, options_int); + torch::stable::Tensor a_ptrs = + torch::stable::empty(num_experts, 
torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor b_ptrs = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor out_ptrs = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor a_scales_ptrs = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor b_scales_ptrs = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor alpha_ptrs = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor layout_sfa = torch::stable::empty( + {num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt, + a.device()); + torch::stable::Tensor layout_sfb = torch::stable::empty( + {num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt, + a.device()); + torch::stable::Tensor a_strides1 = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor b_strides1 = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor c_strides1 = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); run_get_group_gemm_starts( a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, @@ -308,7 +340,7 @@ void run_fp4_blockwise_scaled_group_mm_sm100( typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions; typename Gemm::GemmKernel::TileSchedulerArguments scheduler; scheduler.raster_order = RasterOrderOptions::AlongM; - hw_info.device_id = a.get_device(); + hw_info.device_id = a.get_device_index(); static std::unordered_map cached_sm_counts; if (cached_sm_counts.find(hw_info.device_id) == 
cached_sm_counts.end()) { cached_sm_counts[hw_info.device_id] = @@ -350,32 +382,35 @@ void run_fp4_blockwise_scaled_group_mm_sm100( scheduler}; size_t workspace_size = Gemm::get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + auto workspace = + torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte, + std::nullopt, a.device()); + const cudaStream_t stream = get_current_cuda_stream(a.get_device_index()); auto can_implement_status = gemm_op.can_implement(args); - TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, - "Failed to implement GEMM: status=", (int)can_implement_status); + STD_TORCH_CHECK( + can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM: status=", (int)can_implement_status); // Run the GEMM auto status = gemm_op.initialize(args, workspace.data_ptr()); - TORCH_CHECK(status == cutlass::Status::kSuccess, - "Failed to initialize GEMM: status=", (int)status, - " workspace_size=", workspace_size, " num_experts=", num_experts, - " M=", M, " N=", N, " K=", K); + STD_TORCH_CHECK(status == cutlass::Status::kSuccess, + "Failed to initialize GEMM: status=", (int)status, + " workspace_size=", workspace_size, + " num_experts=", num_experts, " M=", M, " N=", N, " K=", K); status = gemm_op.run(args, workspace.data_ptr(), stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); + STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); } void run_fp4_blockwise_scaled_group_mm_sm120( - torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, - const torch::Tensor& alphas, const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets, const torch::Tensor& 
sf_offsets, int M, - int N, int K) { + torch::stable::Tensor& output, const torch::stable::Tensor& a, + const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale, + const torch::stable::Tensor& b_blockscales, + const torch::stable::Tensor& alphas, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& sf_offsets, int M, int N, int K) { using ProblemShape = cutlass::gemm::GroupProblemShape>; using ElementType = cutlass::float_e2m1_t; @@ -446,20 +481,40 @@ void run_fp4_blockwise_scaled_group_mm_sm120( using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; int num_experts = static_cast(expert_offsets.size(0)); - auto options_int = - torch::TensorOptions().dtype(torch::kInt64).device(a.device()); - torch::Tensor a_ptrs = torch::empty(num_experts, options_int); - torch::Tensor b_ptrs = torch::empty(num_experts, options_int); - torch::Tensor out_ptrs = torch::empty(num_experts, options_int); - torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); - torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); - torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); - torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); - torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); - torch::Tensor a_strides1 = torch::empty(num_experts, options_int); - torch::Tensor b_strides1 = torch::empty(num_experts, options_int); - torch::Tensor c_strides1 = torch::empty(num_experts, options_int); + torch::stable::Tensor a_ptrs = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor b_ptrs = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor out_ptrs = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor 
a_scales_ptrs = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor b_scales_ptrs = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor alpha_ptrs = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor layout_sfa = torch::stable::empty( + {num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt, + a.device()); + torch::stable::Tensor layout_sfb = torch::stable::empty( + {num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt, + a.device()); + torch::stable::Tensor a_strides1 = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor b_strides1 = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); + torch::stable::Tensor c_strides1 = + torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long, + std::nullopt, a.device()); run_get_group_gemm_starts( a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, @@ -480,7 +535,7 @@ void run_fp4_blockwise_scaled_group_mm_sm120( using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; typename Gemm::GemmKernel::TileSchedulerArguments scheduler; scheduler.raster_order = RasterOrderOptions::AlongM; - hw_info.device_id = a.get_device(); + hw_info.device_id = a.get_device_index(); static std::unordered_map cached_sm_counts; if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { cached_sm_counts[hw_info.device_id] = @@ -523,33 +578,36 @@ void run_fp4_blockwise_scaled_group_mm_sm120( scheduler}; size_t workspace_size = Gemm::get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); 
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + auto workspace = + torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte, + std::nullopt, a.device()); + const cudaStream_t stream = get_current_cuda_stream(a.get_device_index()); auto can_implement_status = gemm_op.can_implement(args); - TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, - "Failed to implement GEMM: status=", (int)can_implement_status); + STD_TORCH_CHECK( + can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM: status=", (int)can_implement_status); // Run the GEMM auto status = gemm_op.initialize(args, workspace.data_ptr()); - TORCH_CHECK(status == cutlass::Status::kSuccess, - "Failed to initialize GEMM: status=", (int)status, - " workspace_size=", workspace_size, " num_experts=", num_experts, - " M=", M, " N=", N, " K=", K); + STD_TORCH_CHECK(status == cutlass::Status::kSuccess, + "Failed to initialize GEMM: status=", (int)status, + " workspace_size=", workspace_size, + " num_experts=", num_experts, " M=", M, " N=", N, " K=", K); status = gemm_op.run(args, workspace.data_ptr(), stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); + STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); } template void run_fp4_blockwise_scaled_group_mm( - torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, - const torch::Tensor& alphas, const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, - int N, int K) { + torch::stable::Tensor& output, const torch::stable::Tensor& a, + const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale, + const torch::stable::Tensor& b_blockscales, + const torch::stable::Tensor& alphas, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const 
torch::stable::Tensor& sf_offsets, int M, int N, int K) { int32_t version_num = get_sm_version_num(); #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 if (version_num >= 120 && version_num < 130) { @@ -567,7 +625,7 @@ void run_fp4_blockwise_scaled_group_mm( return; } #endif - TORCH_CHECK_NOT_IMPLEMENTED( + STD_TORCH_CHECK_NOT_IMPLEMENTED( false, "No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ", version_num, ". Required capability: 100 or 120"); @@ -575,26 +633,31 @@ void run_fp4_blockwise_scaled_group_mm( #if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \ (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120) -constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; -constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; +constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte; +constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn; #endif -#define CHECK_TYPE(x, st, m) \ - TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) +#define CHECK_TYPE(x, st, m) \ + STD_TORCH_CHECK(x.scalar_type() == st, \ + ": Inconsistency of torch::stable::Tensor type:", m) #define CHECK_TH_CUDA(x, m) \ - TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.") + STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.") #define CHECK_CONTIGUOUS(x, m) \ - TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.") + STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.") #define CHECK_INPUT(x, st, m) \ CHECK_TH_CUDA(x, m); \ CHECK_CONTIGUOUS(x, m); \ CHECK_TYPE(x, st, m) -void cutlass_fp4_group_mm( - torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, - const torch::Tensor& alphas, const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) { +void cutlass_fp4_group_mm(torch::stable::Tensor& output, + const torch::stable::Tensor& a, + const torch::stable::Tensor& b, + 
const torch::stable::Tensor& a_blockscale, + const torch::stable::Tensor& b_blockscales, + const torch::stable::Tensor& alphas, + const torch::stable::Tensor& problem_sizes, + const torch::stable::Tensor& expert_offsets, + const torch::stable::Tensor& sf_offsets) { #if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \ (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120) // Input validation @@ -602,30 +665,34 @@ void cutlass_fp4_group_mm( CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale"); CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales"); - CHECK_INPUT(alphas, at::ScalarType::Float, "alphas"); + CHECK_INPUT(alphas, torch::headeronly::ScalarType::Float, "alphas"); - TORCH_CHECK(a_blockscale.dim() == 2, - "expected a_blockscale to be of shape [num_experts, rounded_m," - " k // group_size], observed rank: ", - a_blockscale.dim()) - TORCH_CHECK(b_blockscales.dim() == 3, - "expected b_blockscale to be of shape: " - " [num_experts, n, k // group_size], observed rank: ", - b_blockscales.dim()) - TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor"); - TORCH_CHECK(problem_sizes.size(1) == 3, - "problem_sizes must have the shape (num_experts, 3)"); - TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), - "Number of experts in problem_sizes must match expert_offsets"); - TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, - "problem_sizes must be int32."); + STD_TORCH_CHECK( + a_blockscale.dim() == 2, + "expected a_blockscale to be of shape [num_experts, rounded_m," + " k // group_size], observed rank: ", + a_blockscale.dim()) + STD_TORCH_CHECK(b_blockscales.dim() == 3, + "expected b_blockscale to be of shape: " + " [num_experts, n, k // group_size], observed rank: ", + b_blockscales.dim()) + STD_TORCH_CHECK(problem_sizes.dim() == 2, + "problem_sizes must be a 2D tensor"); + STD_TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have the shape (num_experts, 3)"); + STD_TORCH_CHECK( + 
problem_sizes.size(0) == expert_offsets.size(0), + "Number of experts in problem_sizes must match expert_offsets"); + STD_TORCH_CHECK( + problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int, + "problem_sizes must be int32."); int M = static_cast(a.size(0)); int N = static_cast(b.size(1)); int E = static_cast(b.size(0)); int K = static_cast(2 * b.size(2)); - if (output.scalar_type() == torch::kBFloat16) { + if (output.scalar_type() == torch::headeronly::ScalarType::BFloat16) { run_fp4_blockwise_scaled_group_mm( output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, expert_offsets, sf_offsets, M, N, K); @@ -633,7 +700,7 @@ void cutlass_fp4_group_mm( #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 int32_t version_num = get_sm_version_num(); if (version_num >= 120 && version_num < 130) { - TORCH_CHECK_NOT_IMPLEMENTED( + STD_TORCH_CHECK_NOT_IMPLEMENTED( false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ", output.scalar_type()); } @@ -643,7 +710,7 @@ void cutlass_fp4_group_mm( expert_offsets, sf_offsets, M, N, K); } #else - TORCH_CHECK_NOT_IMPLEMENTED( + STD_TORCH_CHECK_NOT_IMPLEMENTED( false, "No compiled cutlass_fp4_group_mm kernel, vLLM must " "be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 " @@ -651,6 +718,6 @@ void cutlass_fp4_group_mm( #endif } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm); +STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) { + m.impl("cutlass_fp4_group_mm", TORCH_BOX(&cutlass_fp4_group_mm)); } diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu similarity index 84% rename from csrc/quantization/fp4/nvfp4_experts_quant.cu rename to csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu index 3162b6cdb..f90bd543a 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu @@ -14,16 +14,15 @@ 
* limitations under the License. */ -#include +#include +#include "libtorch_stable/torch_utils.h" +#include "libtorch_stable/dispatch_utils.h" +#include "cuda_vec_utils.cuh" #include #include -#include -#include - #include -#include "dispatch_utils.h" #include "cuda_utils.h" #include "nvfp4_utils.cuh" @@ -327,25 +326,28 @@ void quant_impl(void* output, void* output_scale, void* input, } // namespace vllm /*Quantization entry for fp4 experts quantization*/ -#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_TH_CUDA(x, m) \ + STD_TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") #define CHECK_CONTIGUOUS(x, m) \ - TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") + STD_TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") #define CHECK_INPUT(x, m) \ CHECK_TH_CUDA(x, m); \ CHECK_CONTIGUOUS(x, m); -constexpr auto HALF = at::ScalarType::Half; -constexpr auto BF16 = at::ScalarType::BFloat16; -constexpr auto FLOAT = at::ScalarType::Float; -constexpr auto INT = at::ScalarType::Int; -constexpr auto UINT8 = at::ScalarType::Byte; +constexpr auto HALF = torch::headeronly::ScalarType::Half; +constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16; +constexpr auto FLOAT = torch::headeronly::ScalarType::Float; +constexpr auto INT = torch::headeronly::ScalarType::Int; +constexpr auto UINT8 = torch::headeronly::ScalarType::Byte; // Common validation for fp4 experts quantization entry points. 
static void validate_fp4_experts_quant_inputs( - torch::Tensor const& output, torch::Tensor const& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk, + torch::stable::Tensor const& output, + torch::stable::Tensor const& output_scale, + torch::stable::Tensor const& input, + torch::stable::Tensor const& input_global_scale, + torch::stable::Tensor const& input_offset_by_experts, + torch::stable::Tensor const& output_scale_offset_by_experts, int64_t m_topk, int64_t k) { CHECK_INPUT(output, "output"); CHECK_INPUT(output_scale, "output_scale"); @@ -354,41 +356,42 @@ static void validate_fp4_experts_quant_inputs( CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts"); CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts"); - TORCH_CHECK(output.dim() == 2); - TORCH_CHECK(output_scale.dim() == 2); - TORCH_CHECK(input.dim() == 2); - TORCH_CHECK(input_global_scale.dim() == 1); - TORCH_CHECK(input_offset_by_experts.dim() == 1); - TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); + STD_TORCH_CHECK(output.dim() == 2); + STD_TORCH_CHECK(output_scale.dim() == 2); + STD_TORCH_CHECK(input.dim() == 2); + STD_TORCH_CHECK(input_global_scale.dim() == 1); + STD_TORCH_CHECK(input_offset_by_experts.dim() == 1); + STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); - TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); - TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); - TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); - TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); + STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); + STD_TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); + STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); + STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); 
// output is uint8 (two nvfp4 values are packed into one uint8) // output_scale is int32 (four fp8 values are packed into one int32) - TORCH_CHECK(output.scalar_type() == UINT8); - TORCH_CHECK(output_scale.scalar_type() == INT); + STD_TORCH_CHECK(output.scalar_type() == UINT8); + STD_TORCH_CHECK(output_scale.scalar_type() == INT); const int BLOCK_SIZE = 16; - TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); + STD_TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); auto n_experts = input_global_scale.size(0); - TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); - TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); - TORCH_CHECK(output.size(0) == m_topk); - TORCH_CHECK(output.size(1) == k / 2); + STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); + STD_TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); + STD_TORCH_CHECK(output.size(0) == m_topk); + STD_TORCH_CHECK(output.size(1) == k / 2); int scales_k = k / BLOCK_SIZE; // 4 means the swizzle requirement by nvidia nvfp4. 
int padded_k = (scales_k + (4 - 1)) / 4 * 4; // 4 means 4 fp8 values are packed into one int32 - TORCH_CHECK(output_scale.size(1) * 4 == padded_k); + STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k); } void scaled_fp4_experts_quant_sm1xxa( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts) { + torch::stable::Tensor& output, torch::stable::Tensor& output_scale, + torch::stable::Tensor const& input, + torch::stable::Tensor const& input_global_scale, + torch::stable::Tensor const& input_offset_by_experts, + torch::stable::Tensor const& output_scale_offset_by_experts) { auto m_topk = input.size(0); auto k = input.size(1); @@ -397,11 +400,11 @@ void scaled_fp4_experts_quant_sm1xxa( output_scale_offset_by_experts, m_topk, k); auto n_experts = input_global_scale.size(0); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = - at::cuda::getCurrentCUDAStream(input.get_device()); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(input.get_device_index()); - VLLM_DISPATCH_HALF_TYPES( + VLLM_STABLE_DISPATCH_HALF_TYPES( input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { using cuda_type = vllm::CUDATypeConverter::Type; vllm::quant_impl( @@ -413,14 +416,15 @@ void scaled_fp4_experts_quant_sm1xxa( } void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts) { + torch::stable::Tensor& output, torch::stable::Tensor& output_scale, + torch::stable::Tensor const& input, + torch::stable::Tensor const& input_global_scale, + torch::stable::Tensor 
const& input_offset_by_experts, + torch::stable::Tensor const& output_scale_offset_by_experts) { auto m_topk = input.size(0); // Input has gate || up layout, so k = input.size(1) / 2 auto k_times_2 = input.size(1); - TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)"); + STD_TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)"); auto k = k_times_2 / 2; validate_fp4_experts_quant_inputs(output, output_scale, input, @@ -428,11 +432,11 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( output_scale_offset_by_experts, m_topk, k); auto n_experts = input_global_scale.size(0); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - const cudaStream_t stream = - at::cuda::getCurrentCUDAStream(input.get_device()); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(input.get_device_index()); - VLLM_DISPATCH_HALF_TYPES( + VLLM_STABLE_DISPATCH_HALF_TYPES( input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] { using cuda_type = vllm::CUDATypeConverter::Type; vllm::quant_impl( diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu new file mode 100644 index 000000000..4adc62437 --- /dev/null +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "libtorch_stable/torch_utils.h" + +#include "cutlass_extensions/common.hpp" +#include "nvfp4_utils.cuh" + +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output, + torch::stable::Tensor const& input, + torch::stable::Tensor const& output_sf, + torch::stable::Tensor const& input_sf, + bool is_sf_swizzled_layout); +#endif + +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void scaled_fp4_experts_quant_sm1xxa( + torch::stable::Tensor& output, torch::stable::Tensor& output_scale, + torch::stable::Tensor const& input, + torch::stable::Tensor const& input_global_scale, + torch::stable::Tensor const& input_offset_by_experts, + torch::stable::Tensor const& output_scale_offset_by_experts); +#endif + +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void silu_and_mul_nvfp4_quant_sm1xxa(torch::stable::Tensor& output, + torch::stable::Tensor& output_sf, + torch::stable::Tensor& input, + torch::stable::Tensor& input_sf); +#endif + +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) +void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( + torch::stable::Tensor& output, torch::stable::Tensor& output_scale, + torch::stable::Tensor const& input, + torch::stable::Tensor const& input_global_scale, + torch::stable::Tensor const& input_offset_by_experts, + torch::stable::Tensor const& output_scale_offset_by_experts); +#endif + +static bool nvfp4_quant_sm_supported() { + const int32_t sm = get_sm_version_num(); +#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 + if (sm >= 100 && sm < 120) return true; +#endif +#if 
defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120 + if (sm >= 120 && sm < 130) return true; +#endif + return false; +} + +void scaled_fp4_quant_out(torch::stable::Tensor const& input, + torch::stable::Tensor const& input_sf, + bool is_sf_swizzled_layout, + torch::stable::Tensor& output, + torch::stable::Tensor& output_sf) { +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + STD_TORCH_CHECK(nvfp4_quant_sm_supported(), + "No compiled nvfp4 quantization kernel for SM ", + get_sm_version_num(), + ". Recompile with the appropriate CUDA arch."); + return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf, + is_sf_swizzled_layout); +#endif + STD_TORCH_CHECK_NOT_IMPLEMENTED(false, + "No compiled nvfp4 quantization kernel"); +} + +std::tuple scaled_fp4_quant_func( + torch::stable::Tensor const& input, torch::stable::Tensor const& input_sf, + bool is_sf_swizzled_layout) { + int64_t n = input.size(-1); + int64_t m = input.numel() / n; + auto device = input.device(); + + // Two fp4 values packed into a uint8 + auto output = torch::stable::empty( + {m, n / 2}, torch::headeronly::ScalarType::Byte, std::nullopt, device); + + torch::stable::Tensor output_sf; + if (is_sf_swizzled_layout) { + auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n); + output_sf = torch::stable::empty( + {sf_m, sf_n}, torch::headeronly::ScalarType::Int, std::nullopt, device); + } else { + output_sf = torch::stable::empty({m, n / CVT_FP4_SF_VEC_SIZE}, + torch::headeronly::ScalarType::Byte, + std::nullopt, device); + } + + scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output, + output_sf); + return {output, output_sf}; +} + +void scaled_fp4_experts_quant( + torch::stable::Tensor& output, torch::stable::Tensor& output_scale, + torch::stable::Tensor const& input, + torch::stable::Tensor const& input_global_scale, + torch::stable::Tensor const& input_offset_by_experts, + torch::stable::Tensor const& 
output_scale_offset_by_experts) { +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + STD_TORCH_CHECK(nvfp4_quant_sm_supported(), + "No compiled nvfp4 experts quantization kernel for SM ", + get_sm_version_num(), + ". Recompile with the appropriate CUDA arch."); + return scaled_fp4_experts_quant_sm1xxa( + output, output_scale, input, input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts); +#endif + STD_TORCH_CHECK_NOT_IMPLEMENTED( + false, "No compiled nvfp4 experts quantization kernel"); +} + +void silu_and_mul_nvfp4_quant(torch::stable::Tensor& output, + torch::stable::Tensor& output_sf, + torch::stable::Tensor& input, + torch::stable::Tensor& input_sf) { +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + STD_TORCH_CHECK(nvfp4_quant_sm_supported(), + "No compiled silu_and_mul nvfp4 quantization kernel for SM ", + get_sm_version_num(), + ". Recompile with the appropriate CUDA arch."); + return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf); +#endif + STD_TORCH_CHECK_NOT_IMPLEMENTED( + false, "No compiled silu_and_mul nvfp4 quantization kernel"); +} + +void silu_and_mul_scaled_fp4_experts_quant( + torch::stable::Tensor& output, torch::stable::Tensor& output_scale, + torch::stable::Tensor const& input, + torch::stable::Tensor const& input_global_scale, + torch::stable::Tensor const& input_offset_by_experts, + torch::stable::Tensor const& output_scale_offset_by_experts) { +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + STD_TORCH_CHECK(nvfp4_quant_sm_supported(), + "No compiled silu_and_mul nvfp4 experts quantization kernel " + "for SM ", + get_sm_version_num(), + ". 
Recompile with the appropriate CUDA arch."); + return silu_and_mul_scaled_fp4_experts_quant_sm1xxa( + output, output_scale, input, input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts); +#endif + STD_TORCH_CHECK_NOT_IMPLEMENTED( + false, "No compiled silu_and_mul nvfp4 experts quantization kernel"); +} diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu similarity index 80% rename from csrc/quantization/fp4/nvfp4_quant_kernels.cu rename to csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu index 773047c22..beb1287c8 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu @@ -14,16 +14,16 @@ * limitations under the License. */ -#include +#include #include #include -#include -#include - #include -#include "dispatch_utils.h" + +#include "libtorch_stable/torch_utils.h" +#include "libtorch_stable/dispatch_utils.h" +#include "cuda_vec_utils.cuh" #include "cuda_utils.h" #include "launch_bounds_utils.h" @@ -173,18 +173,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) } // namespace vllm -void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, - torch::Tensor const& input, - torch::Tensor const& output_sf, - torch::Tensor const& input_sf, +void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output, + torch::stable::Tensor const& input, + torch::stable::Tensor const& output_sf, + torch::stable::Tensor const& input_sf, bool is_sf_swizzled_layout) { int32_t m = input.size(0); int32_t n = input.size(1); - TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); - TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || - input.scalar_type() == at::ScalarType::BFloat16, - "Unsupported input data type for quantize_to_fp4."); + STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); + STD_TORCH_CHECK( + input.scalar_type() == 
torch::headeronly::ScalarType::Half || + input.scalar_type() == torch::headeronly::ScalarType::BFloat16, + "Unsupported input data type for quantize_to_fp4."); int multiProcessorCount = get_device_attribute(cudaDevAttrMultiProcessorCount, -1); @@ -192,8 +193,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, auto input_sf_ptr = static_cast(input_sf.data_ptr()); auto sf_out = static_cast(output_sf.data_ptr()); auto output_ptr = static_cast(output.data_ptr()); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + const torch::stable::accelerator::DeviceGuard device_guard( + input.get_device_index()); + auto stream = get_current_cuda_stream(input.get_device_index()); int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); @@ -213,15 +215,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); dim3 grid(grid_x, grid_y); - VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - auto input_ptr = static_cast(input.data_ptr()); - // NOTE: We don't support e8m0 scales at this moment. 
- vllm::cvt_fp16_to_fp4<<>>( - m, n, num_padded_cols, input_ptr, input_sf_ptr, - reinterpret_cast(output_ptr), - reinterpret_cast(sf_out)); - }); + VLLM_STABLE_DISPATCH_HALF_TYPES( + input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + vllm::cvt_fp16_to_fp4<<>>( + m, n, num_padded_cols, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); } else { int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD; int grid_y = vllm::div_round_up(num_packed_cols, static_cast(block.x)); @@ -229,15 +231,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); dim3 grid(grid_x, grid_y); - VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { - using cuda_type = vllm::CUDATypeConverter::Type; - auto input_ptr = static_cast(input.data_ptr()); - // NOTE: We don't support e8m0 scales at this moment. 
- vllm::cvt_fp16_to_fp4_sf_major - <<>>(m, n, sf_n_unpadded, num_packed_cols, - input_ptr, input_sf_ptr, - reinterpret_cast(output_ptr), - reinterpret_cast(sf_out)); - }); + VLLM_STABLE_DISPATCH_HALF_TYPES( + input.scalar_type(), "nvfp4_quant_kernel", [&] { + using cuda_type = vllm::CUDATypeConverter::Type; + auto input_ptr = static_cast(input.data_ptr()); + vllm::cvt_fp16_to_fp4_sf_major + <<>>( + m, n, sf_n_unpadded, num_packed_cols, input_ptr, input_sf_ptr, + reinterpret_cast(output_ptr), + reinterpret_cast(sf_out)); + }); } } diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu similarity index 55% rename from csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu rename to csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu index 4985bf1eb..d7b2a18e2 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -14,32 +14,39 @@ * limitations under the License. 
*/ -#include -#include +#include + +#include "libtorch_stable/torch_utils.h" + #include "cutlass_extensions/common.hpp" #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 -void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, - torch::Tensor const& B, - torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha); +void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D, + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, + torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha); #endif #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 -void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, - torch::Tensor const& B, - torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha); +void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D, + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, + torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha); #endif -void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A, - const torch::Tensor& B, const torch::Tensor& A_sf, - const torch::Tensor& B_sf, - const torch::Tensor& alpha) { - // Make sure we’re on A’s device. - const c10::cuda::OptionalCUDAGuard device_guard(device_of(A)); +void cutlass_scaled_fp4_mm(torch::stable::Tensor& D, + const torch::stable::Tensor& A, + const torch::stable::Tensor& B, + const torch::stable::Tensor& A_sf, + const torch::stable::Tensor& B_sf, + const torch::stable::Tensor& alpha) { + // Make sure we're on A's device. 
+ const torch::stable::accelerator::DeviceGuard device_guard( + A.get_device_index()); const int32_t sm = get_sm_version_num(); #if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 @@ -56,8 +63,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A, } #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm, - ". Recompile with CUDA >= 12.8 and CC >= 100."); + STD_TORCH_CHECK_NOT_IMPLEMENTED( + false, "No compiled nvfp4 mm kernel for SM ", sm, + ". Recompile with CUDA >= 12.8 and CC >= 100."); } bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) { diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu similarity index 65% rename from csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu rename to csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu index 5bc4c38a2..c8e6c9250 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -14,10 +14,9 @@ * limitations under the License. 
*/ -#include +#include -#include -#include +#include "libtorch_stable/torch_utils.h" #include "cutlass_extensions/common.hpp" @@ -127,8 +126,9 @@ struct Fp4GemmSm100 { template typename Config::Gemm::Arguments args_from_options( - at::Tensor& D, at::Tensor const& A, at::Tensor const& B, - at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha, + torch::stable::Tensor& D, torch::stable::Tensor const& A, + torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha, int64_t M, int64_t N, int64_t K) { using ElementA = typename Config::Gemm::ElementA; using ElementB = typename Config::Gemm::ElementB; @@ -174,19 +174,20 @@ typename Config::Gemm::Arguments args_from_options( } template -void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, - at::Tensor const& A_sf, at::Tensor const& B_sf, - at::Tensor const& alpha, int64_t m, int64_t n, int64_t k, - cudaStream_t stream) { +void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A, + torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha, int64_t m, int64_t n, + int64_t k, cudaStream_t stream) { typename Config::Gemm gemm; auto arguments = args_from_options(D, A, B, A_sf, B_sf, alpha, m, n, k); size_t workspace_size = Config::Gemm::get_workspace_size(arguments); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); - auto workspace = torch::empty(workspace_size, workspace_options); + auto workspace = + torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte, + std::nullopt, A.device()); CUTLASS_CHECK(gemm.can_implement(arguments)); @@ -197,12 +198,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, // Dispatch function to select appropriate config based on M template -void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor 
const& A, - torch::Tensor const& B, - torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha, int64_t m, int64_t n, - int64_t k, cudaStream_t stream) { +void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D, + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, + torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha, int64_t m, + int64_t n, int64_t k, cudaStream_t stream) { uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); if (mp2 <= 16) { @@ -222,61 +224,65 @@ void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, #else template -void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, - torch::Tensor const& B, - torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha, int64_t m, int64_t n, - int64_t k, cudaStream_t stream) { - TORCH_CHECK(false, - "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " - "a CUTLASS 3.8 source directory to enable support."); +void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D, + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, + torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha, int64_t m, + int64_t n, int64_t k, cudaStream_t stream) { + STD_TORCH_CHECK(false, + "Unsupported CUTLASS version. 
Set VLLM_CUTLASS_SRC_DIR to " + "a CUTLASS 3.8 source directory to enable support."); } #endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) -#define CHECK_TYPE(x, st, m) \ - TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) +#define CHECK_TYPE(x, st, m) \ + STD_TORCH_CHECK(x.scalar_type() == st, \ + ": Inconsistency of torch::stable::Tensor type:", m) #define CHECK_TH_CUDA(x, m) \ - TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor") + STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor") #define CHECK_CONTIGUOUS(x, m) \ - TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous") + STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous") #define CHECK_INPUT(x, st, m) \ CHECK_TH_CUDA(x, m); \ CHECK_CONTIGUOUS(x, m); \ CHECK_TYPE(x, st, m) -constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; -constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; +constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte; +constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn; -void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, - torch::Tensor const& B, - torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha) { +void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D, + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, + torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha) { CHECK_INPUT(A, FLOAT4_E2M1X2, "a"); CHECK_INPUT(B, FLOAT4_E2M1X2, "b"); CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); - CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); + CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha"); - TORCH_CHECK(A.dim() == 2, "a must be a matrix"); - TORCH_CHECK(B.dim() == 2, "b must be a matrix"); - TORCH_CHECK(A.sizes()[1] == B.sizes()[1], - "a and b shapes cannot be multiplied (", A.sizes()[0], "x", - A.sizes()[1], " and ", B.sizes()[0], 
"x", B.sizes()[1], ")"); + STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix"); + STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix"); + STD_TORCH_CHECK(A.size(1) == B.size(1), + "a and b shapes cannot be multiplied (", A.size(0), "x", + A.size(1), " and ", B.size(0), "x", B.size(1), ")"); - auto const m = A.sizes()[0]; - auto const n = B.sizes()[0]; - auto const k = A.sizes()[1] * 2; + auto const m = A.size(0); + auto const n = B.size(0); + auto const k = A.size(1) * 2; constexpr int alignment = 32; - TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment, - ", but got a shape: (", A.sizes()[0], "x", A.sizes()[1], - "), k: ", k, "."); - TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment, - ", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ")."); + STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", + alignment, ", but got a shape: (", A.size(0), "x", A.size(1), + "), k: ", k, "."); + STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", + alignment, ", but got b shape: (", B.size(0), "x", B.size(1), + ")."); auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; int rounded_m = round_up(m, 128); @@ -285,33 +291,34 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, // integer. 
int rounded_k = round_up(k / 16, 4); - TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); - TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); - TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1], - "scale_a and scale_b shapes cannot be multiplied (", - A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0], - "x", B_sf.sizes()[1], ")"); - TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, - "scale_a must be padded and swizzled to a shape (", rounded_m, - "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x", - A_sf.sizes()[1], ")"); - TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, - "scale_b must be padded and swizzled to a shape (", rounded_n, - "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x", - B_sf.sizes()[1], ")"); + STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); + STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); + STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1), + "scale_a and scale_b shapes cannot be multiplied (", + A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x", + B_sf.size(1), ")"); + STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k, + "scale_a must be padded and swizzled to a shape (", rounded_m, + "x", rounded_k, "), but got a shape (", A_sf.size(0), "x", + A_sf.size(1), ")"); + STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k, + "scale_b must be padded and swizzled to a shape (", rounded_n, + "x", rounded_k, "), but got a shape (", B_sf.size(0), "x", + B_sf.size(1), ")"); - auto out_dtype = D.dtype(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); + auto out_dtype = D.scalar_type(); + const torch::stable::accelerator::DeviceGuard device_guard( + A.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(A.get_device_index()); - if (out_dtype == at::ScalarType::Half) { + if 
(out_dtype == torch::headeronly::ScalarType::Half) { cutlass_fp4_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } else if (out_dtype == at::ScalarType::BFloat16) { + } else if (out_dtype == torch::headeronly::ScalarType::BFloat16) { cutlass_fp4_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else { - TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype, - ")"); + STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", + out_dtype, ")"); } } diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu similarity index 58% rename from csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu rename to csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu index 89de23b76..b500ae5a0 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu @@ -14,10 +14,9 @@ * limitations under the License. 
*/ -#include +#include -#include -#include +#include "libtorch_stable/torch_utils.h" #include "cutlass_extensions/common.hpp" @@ -34,19 +33,20 @@ using namespace cute; -#define CHECK_TYPE(x, st, m) \ - TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) +#define CHECK_TYPE(x, st, m) \ + STD_TORCH_CHECK(x.scalar_type() == st, \ + ": Inconsistency of torch::stable::Tensor type:", m) #define CHECK_TH_CUDA(x, m) \ - TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor") + STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor") #define CHECK_CONTIGUOUS(x, m) \ - TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous") + STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous") #define CHECK_INPUT(x, st, m) \ CHECK_TH_CUDA(x, m); \ CHECK_CONTIGUOUS(x, m); \ CHECK_TYPE(x, st, m) -constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; -constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; +constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte; +constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn; struct sm120_fp4_config_M256 { using ClusterShape = Shape<_1, _1, _1>; @@ -109,12 +109,13 @@ struct Fp4GemmSm120 { }; template -typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A, - at::Tensor const& B, - at::Tensor const& A_sf, - at::Tensor const& B_sf, - torch::Tensor const& alpha, int M, - int N, int K) { +typename Gemm::Arguments args_from_options(torch::stable::Tensor& D, + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, + torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha, + int M, int N, int K) { using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; using ElementD = typename Gemm::ElementD; @@ -158,18 +159,19 @@ typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A, } template -void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, - at::Tensor 
const& A_sf, at::Tensor const& B_sf, - torch::Tensor const& alpha, int M, int N, int K, +void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A, + torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha, int M, int N, int K, cudaStream_t stream) { Gemm gemm; auto arguments = args_from_options(D, A, B, A_sf, B_sf, alpha, M, N, K); size_t workspace_size = Gemm::get_workspace_size(arguments); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); - auto workspace = torch::empty(workspace_size, workspace_options); + auto workspace = + torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte, + std::nullopt, A.device()); CUTLASS_CHECK(gemm.can_implement(arguments)); @@ -178,12 +180,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); } -void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, - torch::Tensor const& B, - torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha, int m, int n, - int k, cudaStream_t stream) { +void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D, + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, + torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha, int m, + int n, int k, cudaStream_t stream) { uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); if (mp2 <= 256) { runGemm::Gemm>( @@ -194,12 +197,13 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, } } -void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, - torch::Tensor const& B, - torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha, int m, int n, - int k, cudaStream_t stream) { +void 
cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D, + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, + torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha, int m, + int n, int k, cudaStream_t stream) { uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); if (mp2 <= 256) { runGemm::Gemm>( @@ -210,11 +214,12 @@ void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, } } -void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, - torch::Tensor const& B, - torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha) { +void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D, + torch::stable::Tensor const& A, + torch::stable::Tensor const& B, + torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, + torch::stable::Tensor const& alpha) { #if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) CHECK_INPUT(A, FLOAT4_E2M1X2, "a"); CHECK_INPUT(B, FLOAT4_E2M1X2, "b"); @@ -222,24 +227,25 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); - CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); + CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha"); - TORCH_CHECK(A.dim() == 2, "a must be a matrix"); - TORCH_CHECK(B.dim() == 2, "b must be a matrix"); - TORCH_CHECK(A.sizes()[1] == B.sizes()[1], - "a and b shapes cannot be multiplied (", A.sizes()[0], "x", - A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")"); + STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix"); + STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix"); + STD_TORCH_CHECK(A.size(1) == B.size(1), + "a and b shapes cannot be multiplied (", A.size(0), "x", + A.size(1), " and ", B.size(0), "x", B.size(1), ")"); - auto const m = A.sizes()[0]; - auto const n = B.sizes()[0]; - auto const k = A.sizes()[1] * 2; + auto const m = A.size(0); 
+ auto const n = B.size(0); + auto const k = A.size(1) * 2; constexpr int alignment = 32; - TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment, - ", but got a shape: (", A.sizes()[0], "x", A.sizes()[1], - "), k: ", k, "."); - TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment, - ", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ")."); + STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", + alignment, ", but got a shape: (", A.size(0), "x", A.size(1), + "), k: ", k, "."); + STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", + alignment, ", but got b shape: (", B.size(0), "x", B.size(1), + ")."); auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; int rounded_m = round_up(m, 128); @@ -248,38 +254,39 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, // integer. int rounded_k = round_up(k / 16, 4); - TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); - TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); - TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1], - "scale_a and scale_b shapes cannot be multiplied (", - A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0], - "x", B_sf.sizes()[1], ")"); - TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, - "scale_a must be padded and swizzled to a shape (", rounded_m, - "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x", - A_sf.sizes()[1], ")"); - TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, - "scale_b must be padded and swizzled to a shape (", rounded_n, - "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x", - B_sf.sizes()[1], ")"); + STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); + STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); + STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1), + "scale_a and scale_b shapes cannot be multiplied (", + A_sf.size(0), "x", A_sf.size(1), " and ", 
B_sf.size(0), "x", + B_sf.size(1), ")"); + STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k, + "scale_a must be padded and swizzled to a shape (", rounded_m, + "x", rounded_k, "), but got a shape (", A_sf.size(0), "x", + A_sf.size(1), ")"); + STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k, + "scale_b must be padded and swizzled to a shape (", rounded_n, + "x", rounded_k, "), but got a shape (", B_sf.size(0), "x", + B_sf.size(1), ")"); - auto out_dtype = D.dtype(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); + auto out_dtype = D.scalar_type(); + const torch::stable::accelerator::DeviceGuard device_guard( + A.get_device_index()); + const cudaStream_t stream = get_current_cuda_stream(A.get_device_index()); - if (out_dtype == at::ScalarType::BFloat16) { + if (out_dtype == torch::headeronly::ScalarType::BFloat16) { return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } else if (out_dtype == at::ScalarType::Half) { + } else if (out_dtype == torch::headeronly::ScalarType::Half) { return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else { - TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (", - out_dtype, ")"); + STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (", + out_dtype, ")"); } #else - TORCH_CHECK(false, - "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " - "a CUTLASS 3.8 source directory to enable support."); + STD_TORCH_CHECK(false, + "Unsupported CUTLASS version. 
Set VLLM_CUTLASS_SRC_DIR to " + "a CUTLASS 3.8 source directory to enable support."); #endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) -} \ No newline at end of file +} diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/libtorch_stable/quantization/fp4/nvfp4_utils.cuh similarity index 99% rename from csrc/quantization/fp4/nvfp4_utils.cuh rename to csrc/libtorch_stable/quantization/fp4/nvfp4_utils.cuh index 0c04f0108..590e4c06b 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_utils.cuh @@ -20,7 +20,7 @@ #include #include -#include "../../cuda_vec_utils.cuh" +#include "cuda_vec_utils.cuh" #if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \ CUDA_VERSION >= 12090 diff --git a/csrc/libtorch_stable/torch_bindings.cpp b/csrc/libtorch_stable/torch_bindings.cpp index c004c6786..c31844948 100644 --- a/csrc/libtorch_stable/torch_bindings.cpp +++ b/csrc/libtorch_stable/torch_bindings.cpp @@ -103,6 +103,102 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) { ops.def( "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> " "bool"); + + // CUTLASS nvfp4 block scaled GEMM + ops.def( + "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," + " Tensor block_scale_a, Tensor block_scale_b," + " Tensor alpha) -> ()"); + + // cutlass nvfp4 block scaled group GEMM + ops.def( + "cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b," + " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas," + " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()"); + + // Compute NVFP4 block quantized tensor. 
+ ops.def( + "scaled_fp4_quant(Tensor input," + " Tensor input_scale, bool " + "is_sf_swizzled_layout) -> (Tensor, Tensor)"); + + // Out variant + // TODO: Add out_variant tag once PyTorch supports it (added in 2.11) + // This registration is now migrated to stable ABI + // at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not + // yet in torch/headeronly), the tag should be applied from Python + // via torch.library.Library.define(..., tags=(torch.Tag.out_variant,)) + // with the .impl remaining in C++. + // See pytorch/pytorch#176117. + ops.def( + "scaled_fp4_quant.out(Tensor input," + " Tensor input_scale, bool " + "is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) " + "-> ()"); + + // Compute NVFP4 experts quantization. + ops.def( + "scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale," + "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," + "Tensor output_scale_offset_by_experts) -> ()"); + + // Fused SiLU+Mul+NVFP4 experts quantization. + ops.def( + "silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! " + "output_scale," + "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," + "Tensor output_scale_offset_by_experts) -> ()"); + + // Fused SiLU+Mul+NVFP4 quantization. + ops.def( + "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, " + "Tensor input, Tensor input_global_scale) -> ()"); + + // Check if cutlass_scaled_mm_fp4 is supported for CUDA devices + // of the given capability + ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"); + + // CUTLASS w4a8 GEMM + ops.def( + "cutlass_w4a8_mm(" + " Tensor A," + " Tensor B," + " Tensor group_scales," + " int group_size," + " Tensor channel_scales," + " Tensor token_scales," + " ScalarType? out_type," + " str? 
maybe_schedule" + ") -> Tensor"); + + // pack scales + ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor"); + + // encode and reorder weight matrix + ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); + + // CUTLASS w4a8 grouped GEMM + ops.def( + "cutlass_w4a8_moe_mm(" + " Tensor! out_tensors," + " Tensor a_tensors," + " Tensor b_tensors," + " Tensor a_scales," + " Tensor b_scales," + " Tensor b_group_scales," + " int b_group_size," + " Tensor expert_offsets," + " Tensor problem_sizes," + " Tensor a_strides," + " Tensor b_strides," + " Tensor c_strides," + " Tensor group_scale_strides," + " str? maybe_schedule" + ") -> ()"); + + ops.def( + "cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, " + "Tensor)"); #endif } @@ -128,6 +224,18 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) { TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets)); ops.impl("get_cutlass_batched_moe_mm_data", TORCH_BOX(&get_cutlass_batched_moe_mm_data)); + + // FP4/NVFP4 ops + ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm)); + ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func)); + ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out)); + ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant)); + ops.impl("silu_and_mul_scaled_fp4_experts_quant", + TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant)); + ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant)); + + // W4A8 ops: impl registrations are in the source files + // (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu) #endif } @@ -143,6 +251,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) { TORCH_BOX(&cutlass_group_gemm_supported)); ops.impl("cutlass_scaled_mm_supports_block_fp8", TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8)); + ops.impl("cutlass_scaled_mm_supports_fp4", + TORCH_BOX(&cutlass_scaled_mm_supports_fp4)); #endif } diff --git a/csrc/libtorch_stable/torch_utils.h 
b/csrc/libtorch_stable/torch_utils.h index 19ad6a977..f5a80d63e 100644 --- a/csrc/libtorch_stable/torch_utils.h +++ b/csrc/libtorch_stable/torch_utils.h @@ -2,6 +2,7 @@ #include #include +#include #include #include diff --git a/csrc/ops.h b/csrc/ops.h index 580fdfc6b..1fdd77f73 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -152,12 +152,6 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); -#ifndef USE_ROCM -void silu_and_mul_nvfp4_quant(torch::Tensor& out, - torch::Tensor& output_block_scale, - torch::Tensor& input, - torch::Tensor& input_global_scale); -#endif void persistent_masked_m_silu_mul_quant( const at::Tensor& input, // (E, T, 2*H) const at::Tensor& counts, // (E) @@ -225,44 +219,6 @@ torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W, int64_t ggml_moe_get_block_size(int64_t type); -#ifndef USE_ROCM - -bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability); - -void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, - torch::Tensor const& B, torch::Tensor const& A_sf, - torch::Tensor const& B_sf, - torch::Tensor const& alpha); - -void cutlass_fp4_group_mm( - torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, - const torch::Tensor& alphas, const torch::Tensor& problem_sizes, - const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets); - -std::tuple scaled_fp4_quant_func( - torch::Tensor const& input, torch::Tensor const& input_scale, - bool is_sf_swizzled_layout); - -void scaled_fp4_quant_out(torch::Tensor const& input, - torch::Tensor const& input_scale, - bool is_sf_swizzled_layout, torch::Tensor& output, - torch::Tensor& output_scale); - -void scaled_fp4_experts_quant( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& 
input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts); - -void silu_and_mul_scaled_fp4_experts_quant( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts); - -#endif - void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale, std::optional const& azp); diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu deleted file mode 100644 index 8729b01e4..000000000 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include - -#include "cutlass_extensions/common.hpp" -#include "nvfp4_utils.cuh" - -#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) -void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, - torch::Tensor const& input, - torch::Tensor const& output_sf, - torch::Tensor const& input_sf, - bool is_sf_swizzled_layout); -#endif - -#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) -void scaled_fp4_experts_quant_sm1xxa( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts); -#endif - -#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) -void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, - torch::Tensor& output_sf, - torch::Tensor& input, - torch::Tensor& input_sf); -#endif - -#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) -void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts); -#endif - -static bool nvfp4_quant_sm_supported() { - const int32_t sm = get_sm_version_num(); -#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 - if (sm >= 100 && sm < 120) return true; -#endif -#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120 - if (sm >= 120 && sm < 130) return true; -#endif - return false; -} - -void scaled_fp4_quant_out(torch::Tensor const& input, - torch::Tensor const& input_sf, - bool is_sf_swizzled_layout, torch::Tensor& output, - torch::Tensor& output_sf) { -#if (defined(ENABLE_NVFP4_SM100) && 
ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) - TORCH_CHECK(nvfp4_quant_sm_supported(), - "No compiled nvfp4 quantization kernel for SM ", - get_sm_version_num(), - ". Recompile with the appropriate CUDA arch."); - return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf, - is_sf_swizzled_layout); -#endif - TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); -} - -std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func( - torch::Tensor const& input, torch::Tensor const& input_sf, - bool is_sf_swizzled_layout) { - int64_t n = input.size(-1); - int64_t m = input.numel() / n; - auto device = input.device(); - - // Two fp4 values packed into a uint8 - auto output = torch::empty( - {m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8)); - - torch::Tensor output_sf; - if (is_sf_swizzled_layout) { - auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n); - output_sf = torch::empty( - {sf_m, sf_n}, - torch::TensorOptions().device(device).dtype(torch::kInt32)); - } else { - output_sf = torch::empty( - {m, n / CVT_FP4_SF_VEC_SIZE}, - torch::TensorOptions().device(device).dtype(torch::kUInt8)); - } - - scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output, - output_sf); - return {output, output_sf}; -} - -void scaled_fp4_experts_quant( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts) { -#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) - TORCH_CHECK(nvfp4_quant_sm_supported(), - "No compiled nvfp4 experts quantization kernel for SM ", - get_sm_version_num(), - ".
Recompile with the appropriate CUDA arch."); - return scaled_fp4_experts_quant_sm1xxa( - output, output_scale, input, input_global_scale, input_offset_by_experts, - output_scale_offset_by_experts); -#endif - TORCH_CHECK_NOT_IMPLEMENTED(false, - "No compiled nvfp4 experts quantization kernel"); -} - -void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf, - torch::Tensor& input, torch::Tensor& input_sf) { -#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) - TORCH_CHECK(nvfp4_quant_sm_supported(), - "No compiled silu_and_mul nvfp4 quantization kernel for SM ", - get_sm_version_num(), - ". Recompile with the appropriate CUDA arch."); - return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf); -#endif - TORCH_CHECK_NOT_IMPLEMENTED( - false, "No compiled silu_and_mul nvfp4 quantization kernel"); -} - -void silu_and_mul_scaled_fp4_experts_quant( - torch::Tensor& output, torch::Tensor& output_scale, - torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts) { -#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ - (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) - TORCH_CHECK(nvfp4_quant_sm_supported(), - "No compiled silu_and_mul nvfp4 experts quantization kernel " - "for SM ", - get_sm_version_num(), - ". 
Recompile with the appropriate CUDA arch."); - return silu_and_mul_scaled_fp4_experts_quant_sm1xxa( - output, output_scale, input, input_global_scale, input_offset_by_experts, - output_scale_offset_by_experts); -#endif - TORCH_CHECK_NOT_IMPLEMENTED( - false, "No compiled silu_and_mul nvfp4 experts quantization kernel"); -} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b7ab51c1d..85605458f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -109,13 +109,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); -#ifndef USE_ROCM - ops.def( - "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, " - "Tensor input, Tensor input_global_scale) -> ()"); - ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant); -#endif - ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); @@ -332,47 +325,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? qzeros_or_none, bool inplace) -> Tensor"); // conditionally compiled so impl registrations are in source file - // CUTLASS w4a8 GEMM - ops.def( - "cutlass_w4a8_mm(" - " Tensor A," - " Tensor B," - " Tensor group_scales," - " int group_size," - " Tensor channel_scales," - " Tensor token_scales," - " ScalarType? out_type," - " str? maybe_schedule" - ") -> Tensor"); - // pack scales - ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor"); - // encode and reorder weight matrix - ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); - // conditionally compiled so impl registration is in source file - - // CUTLASS w4a8 grouped GEMM - ops.def( - "cutlass_w4a8_moe_mm(" - " Tensor! 
out_tensors," - " Tensor a_tensors," - " Tensor b_tensors," - " Tensor a_scales," - " Tensor b_scales," - " Tensor b_group_scales," - " int b_group_size," - " Tensor expert_offsets," - " Tensor problem_sizes," - " Tensor a_strides," - " Tensor b_strides," - " Tensor c_strides," - " Tensor group_scale_strides," - " str? maybe_schedule" - ") -> ()"); - ops.def( - "cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, " - "Tensor)"); - // conditionally compiled so impl registration is in source file - #endif // Dequantization for GGML. @@ -409,20 +361,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); #ifndef USE_ROCM - // CUTLASS nvfp4 block scaled GEMM - ops.def( - "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," - " Tensor block_scale_a, Tensor block_scale_b," - " Tensor alpha) -> ()"); - ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); - - // cutlass nvfp4 block scaled group GEMM - ops.def( - "cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b," - " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas," - " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()"); - // conditionally compiled so impl registration is in source file - // Expert-specialization mxfp8 blockscaled grouped quantization (SM100+). ops.def( "mxfp8_experts_quant(" @@ -455,44 +393,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "-> int"); // conditionally compiled so impl in source file - // Compute NVFP4 block quantized tensor. - ops.def( - "scaled_fp4_quant(Tensor input," - " Tensor input_scale, bool " - "is_sf_swizzled_layout) -> (Tensor, Tensor)"); - ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func); - - // Out variant - // TODO: Add {at::Tag::out_variant} tag and update all call sites - // to use the functional variant once vLLM upgrades PyTorch. - // See pytorch/pytorch#176117. 
- ops.def( - "scaled_fp4_quant.out(Tensor input," - " Tensor input_scale, bool " - "is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) " - "-> ()"); - ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out); - - // Compute NVFP4 experts quantization. - ops.def( - "scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale," - "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," - "Tensor output_scale_offset_by_experts) -> ()"); - ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant); - - // Fused SiLU+Mul+NVFP4 experts quantization. - ops.def( - "silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! " - "output_scale," - "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," - "Tensor output_scale_offset_by_experts) -> ()"); - ops.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, - &silu_and_mul_scaled_fp4_experts_quant); - - // Check if cutlass_scaled_mm_fp4 is supported for CUDA devices - // of the given capability - ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"); - ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4); #endif // Quantized GEMM for GPTQ.