[4/n] Migrate FP4/W4A8 CUTLASS kernels to torch stable ABI (#37503)

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
mikaylagawarecki
2026-03-31 13:21:13 -04:00
committed by GitHub
parent 0dd25a44ea
commit 7c080dd3c5
27 changed files with 1205 additions and 1016 deletions

View File

@@ -340,8 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp") "csrc/cutlass_extensions/common.cpp")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
@@ -489,59 +487,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures") " in CUDA target architectures")
endif() endif()
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# CUTLASS MLA Archs and flags # CUTLASS MLA Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
@@ -681,34 +626,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
endif() endif()
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
# Hadacore kernels # Hadacore kernels
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}") cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
@@ -760,7 +677,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(VLLM_STABLE_EXT_SRC set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp" "csrc/libtorch_stable/torch_bindings.cpp"
"csrc/cutlass_extensions/common.cpp" "csrc/cutlass_extensions/common.cpp"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu") "csrc/cuda_utils_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC list(APPEND VLLM_STABLE_EXT_SRC
@@ -978,6 +898,96 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
endif() endif()
#
# FP4/NVFP4 kernels (moved from _C to _C_stable_libtorch)
#
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
#
# W4A8 kernels (moved from _C to _C_stable_libtorch)
#
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
message(STATUS "Enabling C_stable extension.") message(STATUS "Enabling C_stable extension.")
define_extension_target( define_extension_target(
_C_stable_libtorch _C_stable_libtorch

View File

@@ -3,8 +3,8 @@
#pragma once #pragma once
#include <c10/util/BFloat16.h> #include <torch/headeronly/util/BFloat16.h>
#include <c10/util/Half.h> #include <torch/headeronly/util/Half.h>
#include <cassert> #include <cassert>
#ifdef USE_ROCM #ifdef USE_ROCM

View File

@@ -1,7 +1,6 @@
#pragma once #pragma once
#include <cute/tensor.hpp> #include <cute/tensor.hpp>
#include <torch/all.h>
namespace cute { namespace cute {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////

View File

@@ -189,9 +189,9 @@ struct Sm90RowOrScalarBroadcastArray {
} }
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) { for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -211,8 +211,8 @@ struct Sm90RowOrScalarBroadcastArray {
begin_loop(int epi_m, int epi_n) { begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt); copy(tSR_sRow_flt, tSR_rRow_flt);
} }
} }
@@ -241,9 +241,9 @@ struct Sm90RowOrScalarBroadcastArray {
auto [m, n, k, l] = args.tile_coord_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy)); using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem), cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem //// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -251,16 +251,16 @@ struct Sm90RowOrScalarBroadcastArray {
Stride<_0, _1>>{}, Stride<_0, _1>>{},
Layout<_1>{}); Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow); cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow); cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord //// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow); cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg //// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>( return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow, tGS_gRow,
@@ -389,7 +389,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE void CUTLASS_DEVICE void
begin() { begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol)); cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) { for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m; pred(i) = get<0>(tCcCol(i)) < m;
@@ -409,7 +409,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE Array<Element, FragmentSize> CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) { visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col; Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) { for (int i = 0; i < FragmentSize; ++i) {
@@ -431,16 +431,16 @@ struct Sm90ColOrScalarBroadcastArray {
auto [M, N, K, L] = args.problem_shape_mnkl; auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and // Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate // partition the same way, this will be used to generate the predicate
// tensor for loading // tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape()); cute::Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks( return ConsumerStoreCallbacks(

View File

@@ -186,9 +186,9 @@ struct Sm90RowOrScalarBroadcast {
} }
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) { for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -208,8 +208,8 @@ struct Sm90RowOrScalarBroadcast {
begin_loop(int epi_m, int epi_n) { begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt); copy(tSR_sRow_flt, tSR_rRow_flt);
} }
} }
@@ -238,9 +238,9 @@ struct Sm90RowOrScalarBroadcast {
auto [m, n, k, l] = args.tile_coord_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy)); using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem), cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem //// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -248,16 +248,16 @@ struct Sm90RowOrScalarBroadcast {
Stride<_0, _1>>{}, Stride<_0, _1>>{},
Layout<_1>{}); Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow); cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow); cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord //// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow); cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg //// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>( return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow, tGS_gRow,
@@ -382,7 +382,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE void CUTLASS_DEVICE void
begin() { begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol)); cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) { for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m; pred(i) = get<0>(tCcCol(i)) < m;
@@ -402,7 +402,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE Array<Element, FragmentSize> CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) { visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col; Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) { for (int i = 0; i < FragmentSize; ++i) {
@@ -422,16 +422,16 @@ struct Sm90ColOrScalarBroadcast {
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) { get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl; auto [M, N, K, L] = args.problem_shape_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and // Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate // partition the same way, this will be used to generate the predicate
// tensor for loading // tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape()); cute::Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks( return ConsumerStoreCallbacks(

View File

@@ -1,6 +1,21 @@
#pragma once #pragma once
#include <torch/all.h> // This header is shared between _C (unstable ABI, used by machete) and
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
// is defined only for the stable target, so we switch includes and types
// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
using TorchTensor = torch::stable::Tensor;
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
#else
#include <torch/all.h>
using TorchTensor = torch::Tensor;
#define TORCH_UTILS_CHECK TORCH_CHECK
#endif
#include "cute/layout.hpp" #include "cute/layout.hpp"
#include "cutlass/layout/matrix.h" #include "cutlass/layout/matrix.h"
@@ -55,35 +70,35 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra // If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1. // strides are set to be 0 or 1.
template <typename Stride> template <typename Stride>
static inline auto make_cute_layout(torch::Tensor const& tensor, static inline auto make_cute_layout(TorchTensor const& tensor,
std::string_view name = "tensor") { std::string_view name = "tensor") {
TORCH_CHECK(tensor.dim() <= rank(Stride{})); TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx( auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele,
Stride{}, [&](auto const& stride_ele, auto const& idx) { auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>; using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.dim()) { if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) { if constexpr (cute::is_static_v<StrideEle>) {
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value); name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{}; return StrideEle{};
} else { } else {
if (tensor.size(idx) == 1) { if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for // use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims) // cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0}; return StrideEle{0};
} else {
return tensor.stride(idx);
}
}
} else { } else {
// Extra strides are assumed to be 0 or 1 return tensor.stride(idx);
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
} }
}); }
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) { auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
if (idx < tensor.dim()) if (idx < tensor.dim())
@@ -97,7 +112,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
template <typename Stride> template <typename Stride>
static inline auto maybe_make_cute_layout( static inline auto maybe_make_cute_layout(
std::optional<torch::Tensor> const& tensor, std::optional<TorchTensor> const& tensor,
std::string_view name = "tensor") { std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor)); using Layout = decltype(make_cute_layout<Stride>(*tensor));
@@ -121,12 +136,12 @@ template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type; using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <> template <>
struct equivalent_cutlass_type<c10::Half> { struct equivalent_cutlass_type<torch::headeronly::Half> {
using type = cutlass::half_t; using type = cutlass::half_t;
}; };
template <> template <>
struct equivalent_cutlass_type<c10::BFloat16> { struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
using type = cutlass::bfloat16_t; using type = cutlass::bfloat16_t;
}; };
@@ -134,8 +149,8 @@ struct equivalent_cutlass_type<c10::BFloat16> {
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type) // equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
// //
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from // Return a `torch::headeronly::CppTypeToScalarType<T>` compatible type, i.e.
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` // get the C++ type equivalent to T, e.g.: `cutlass::half_t -> Half`
template <typename T> template <typename T>
struct equivalent_scalar_type { struct equivalent_scalar_type {
using type = T; using type = T;
@@ -146,15 +161,15 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <> template <>
struct equivalent_scalar_type<cutlass::half_t> { struct equivalent_scalar_type<cutlass::half_t> {
using type = c10::Half; using type = torch::headeronly::Half;
}; };
template <> template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> { struct equivalent_scalar_type<cutlass::bfloat16_t> {
using type = c10::BFloat16; using type = torch::headeronly::BFloat16;
}; };
// get equivalent c10::ScalarType tag from compile time type // get equivalent torch::headeronly::ScalarType tag from compile time type
template <typename T> template <typename T>
static inline constexpr c10::ScalarType equivalent_scalar_type_v = static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v =
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value; torch::headeronly::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;

View File

@@ -49,6 +49,15 @@
THO_DISPATCH_SWITCH(TYPE, NAME, \ THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
// Half types dispatch (Half + BFloat16)
#define VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(...) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_STABLE_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))
// Boolean dispatch // Boolean dispatch
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \ #define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
if (expr) { \ if (expr) { \

View File

@@ -84,4 +84,54 @@ void get_cutlass_batched_moe_mm_data(
const torch::stable::Tensor& expert_num_tokens, const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n, const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k); const int64_t k);
// FP4/NVFP4 ops
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets);
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale, bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch::stable::Tensor& output_block_scale,
torch::stable::Tensor& input,
torch::stable::Tensor& input_global_scale);
#endif #endif

View File

@@ -2,10 +2,9 @@
#pragma once #pragma once
#include <cuda.h> #include <cuda.h>
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <c10/cuda/CUDAStream.h> #include "libtorch_stable/torch_utils.h"
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h" #include "cutlass/bfloat16.h"
#include "cutlass/float8.h" #include "cutlass/float8.h"
@@ -41,7 +40,7 @@ __global__ void get_group_gemm_starts(
} }
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \ #define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \ get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
cutlass::Array<cutlass::float_e4m3_t, 8>> \ cutlass::Array<cutlass::float_e4m3_t, 8>> \
<<<1, num_experts, 0, stream>>>( \ <<<1, num_experts, 0, stream>>>( \
@@ -66,23 +65,34 @@ __global__ void get_group_gemm_starts(
namespace { namespace {
void run_get_group_gemm_starts( void run_get_group_gemm_starts(
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs, torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs, torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs, torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors, torch::stable::Tensor& b_group_scales_ptrs,
torch::Tensor const& b_tensors, torch::Tensor& out_tensors, torch::stable::Tensor const& a_tensors,
torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
torch::Tensor const& b_group_scales, const int64_t b_group_size) { torch::stable::Tensor const& a_scales,
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); torch::stable::Tensor const& b_scales,
TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32 torch::stable::Tensor const& b_group_scales, const int64_t b_group_size) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32); STD_TORCH_CHECK(a_tensors.scalar_type() ==
TORCH_CHECK(b_scales.dtype() == torch::kFloat32); torch::headeronly::ScalarType::Float8_e4m3fn);
TORCH_CHECK(b_group_scales.dtype() == STD_TORCH_CHECK(
torch::kFloat8_e4m3fn); // the underlying torch type is e4m3 b_tensors.scalar_type() ==
TORCH_CHECK(out_tensors.dtype() == torch::headeronly::ScalarType::Int); // int4 8x packed into int32
torch::kBFloat16); // only support bf16 for now STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(
b_group_scales.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn); // the underlying torch
// type is e4m3
STD_TORCH_CHECK(
out_tensors.scalar_type() ==
torch::headeronly::ScalarType::BFloat16); // only support bf16 for now
// expect int64_t to avoid overflow during offset calculations // expect int64_t to avoid overflow during offset calculations
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64); STD_TORCH_CHECK(expert_offsets.scalar_type() ==
torch::headeronly::ScalarType::Long);
int num_experts = static_cast<int>(expert_offsets.size(0)); int num_experts = static_cast<int>(expert_offsets.size(0));
// logical k, n // logical k, n
@@ -90,15 +100,16 @@ void run_get_group_gemm_starts(
int64_t k = a_tensors.size(1); int64_t k = a_tensors.size(1);
int64_t scale_k = cutlass::ceil_div(k, b_group_size); int64_t scale_k = cutlass::ceil_div(k, b_group_size);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); auto stream = get_current_cuda_stream(a_tensors.get_device_index());
if (false) { if (false) {
} }
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t) __CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half) cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
else { else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
} }
} }
} // namespace } // namespace

View File

@@ -14,13 +14,12 @@
#include "cutlass/util/mixed_dtype_utils.hpp" #include "cutlass/util/mixed_dtype_utils.hpp"
// vllm includes // vllm includes
#include <ATen/cuda/CUDAContext.h> #include <torch/csrc/stable/library.h>
#include <c10/cuda/CUDAGuard.h> #include <torch/csrc/stable/tensor.h>
#include <torch/all.h> #include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp" #include "cutlass_extensions/torch_utils.hpp"
#include "cutlass_extensions/common.hpp" #include "cutlass_extensions/common.hpp"
#include "core/registration.h"
#include "get_group_starts.cuh" #include "get_group_starts.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "w4a8_utils.cuh" #include "w4a8_utils.cuh"
@@ -168,31 +167,40 @@ struct W4A8GroupedGemmKernel {
static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0, static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
"LayoutB_Reordered size must be divisible by 4 bytes"); "LayoutB_Reordered size must be divisible by 4 bytes");
static void grouped_mm( static void grouped_mm(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, const torch::Tensor& a_tensors, const torch::stable::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales, const torch::stable::Tensor& b_tensors,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, const torch::stable::Tensor& a_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets, const torch::stable::Tensor& b_scales,
const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides, const torch::stable::Tensor& b_group_scales,
const torch::Tensor& b_strides, const torch::Tensor& c_strides, const int64_t b_group_size,
const torch::Tensor& group_scale_strides) { const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes_torch,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides) {
auto device = a_tensors.device(); auto device = a_tensors.device();
auto device_id = device.index(); auto device_id = device.index();
const at::cuda::OptionalCUDAGuard device_guard(device); const torch::stable::accelerator::DeviceGuard device_guard(device_id);
auto stream = at::cuda::getCurrentCUDAStream(device_id); auto stream = get_current_cuda_stream(device_id);
int num_experts = static_cast<int>(expert_offsets.size(0)); int num_experts = static_cast<int>(expert_offsets.size(0));
int n = static_cast<int>(b_tensors.size(1)); int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor; int k = static_cast<int>(b_tensors.size(2)) * PackFactor;
auto options_int = torch::stable::Tensor a_ptrs = torch::stable::empty(
torch::TensorOptions().dtype(torch::kInt64).device(device); num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::Tensor a_ptrs = torch::empty(num_experts, options_int); torch::stable::Tensor b_ptrs = torch::stable::empty(
torch::Tensor b_ptrs = torch::empty(num_experts, options_int); num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int); torch::stable::Tensor out_ptrs = torch::stable::empty(
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int); num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_group_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
// get the correct offsets to pass to gemm // get the correct offsets to pass to gemm
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs, run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
@@ -247,9 +255,9 @@ struct W4A8GroupedGemmKernel {
// Allocate workspace // Allocate workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments); size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::Tensor workspace = torch::stable::Tensor workspace = torch::stable::empty(
torch::empty(workspace_size, workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
torch::TensorOptions().dtype(torch::kU8).device(device)); device);
// Run GEMM // Run GEMM
GemmShuffled gemm; GemmShuffled gemm;
@@ -294,14 +302,20 @@ using Kernel_256x128_2x1x1_Coop =
using Kernel_128x256_2x1x1_Coop = using Kernel_128x256_2x1x1_Coop =
W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>; W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;
void mm_dispatch( void mm_dispatch(torch::stable::Tensor& out_tensors,
torch::Tensor& out_tensors, const torch::Tensor& a_tensors, const torch::stable::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales, const torch::stable::Tensor& b_tensors,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, const torch::stable::Tensor& a_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets, const torch::stable::Tensor& b_scales,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, const torch::stable::Tensor& b_group_scales,
const torch::Tensor& b_strides, const torch::Tensor& c_strides, const int64_t b_group_size,
const torch::Tensor& group_scale_strides, const std::string& schedule) { const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
const std::string& schedule) {
if (schedule == "Kernel_128x16_1x1x1_Coop") { if (schedule == "Kernel_128x16_1x1x1_Coop") {
Kernel_128x16_1x1x1_Coop::grouped_mm( Kernel_128x16_1x1x1_Coop::grouped_mm(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales, out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
@@ -358,18 +372,23 @@ void mm_dispatch(
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides, b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, group_scale_strides); c_strides, group_scale_strides);
} else { } else {
TORCH_CHECK(false, STD_TORCH_CHECK(false,
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule); "cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
} }
} }
void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors, void mm(torch::stable::Tensor& out_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales, const torch::stable::Tensor& a_tensors,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales, const torch::stable::Tensor& b_tensors,
const int64_t b_group_size, const torch::Tensor& expert_offsets, const torch::stable::Tensor& a_scales,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides, const torch::stable::Tensor& b_scales,
const torch::Tensor& b_strides, const torch::Tensor& c_strides, const torch::stable::Tensor& b_group_scales, const int64_t b_group_size,
const torch::Tensor& group_scale_strides, const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
std::optional<std::string> maybe_schedule) { std::optional<std::string> maybe_schedule) {
// user has specified a schedule // user has specified a schedule
if (maybe_schedule) { if (maybe_schedule) {
@@ -406,26 +425,27 @@ void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
a_strides, b_strides, c_strides, group_scale_strides, schedule); a_strides, b_strides, c_strides, group_scale_strides, schedule);
} }
std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b( std::tuple<torch::stable::Tensor, torch::stable::Tensor>
torch::Tensor const& b_tensors) { encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) {
TORCH_CHECK(b_tensors.dtype() == torch::kInt32); STD_TORCH_CHECK(b_tensors.scalar_type() ==
TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k) torch::headeronly::ScalarType::Int);
TORCH_CHECK(b_tensors.is_contiguous()); STD_TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
TORCH_CHECK(b_tensors.is_cuda()); STD_TORCH_CHECK(b_tensors.is_contiguous());
STD_TORCH_CHECK(b_tensors.is_cuda());
int n = static_cast<int>(b_tensors.size(1)); int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor; // logical k int k = static_cast<int>(b_tensors.size(2)) * PackFactor; // logical k
// CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0. // CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
// These misalignments cause silent OOB unless run under Compute Sanitizer. // These misalignments cause silent OOB unless run under Compute Sanitizer.
TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256"); STD_TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
TORCH_CHECK(n % 16 == 0, "n must be divisible by 16"); STD_TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
// we will store the layout to an int32 tensor; // we will store the layout to an int32 tensor;
// this is the number of elements we need per layout // this is the number of elements we need per layout
constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t); constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);
torch::Tensor b_tensors_packed = torch::empty_like(b_tensors); torch::stable::Tensor b_tensors_packed = torch::stable::empty_like(b_tensors);
int num_experts = static_cast<int>(b_tensors.size(0)); int num_experts = static_cast<int>(b_tensors.size(0));
auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr()); auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
@@ -435,7 +455,7 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
size_t num_int4_elems = 1ull * num_experts * n * k; size_t num_int4_elems = 1ull * num_experts * n * k;
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr, bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
num_int4_elems); num_int4_elems);
TORCH_CHECK(ok, "unified_encode_int4b failed"); STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
// construct the layout once; assumes each expert has the same layout // construct the layout once; assumes each expert has the same layout
using LayoutType = LayoutB_Reordered; using LayoutType = LayoutB_Reordered;
@@ -456,28 +476,28 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
} }
// save the packed layout to torch tensor so we can re-use it // save the packed layout to torch tensor so we can re-use it
auto cpu_opts = torch::stable::Tensor layout_cpu = torch::stable::empty(
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); {num_experts, layout_width}, torch::headeronly::ScalarType::Int,
torch::Tensor layout_cpu = std::nullopt, torch::stable::Device(torch::stable::DeviceType::CPU));
torch::empty({num_experts, layout_width}, cpu_opts);
int32_t* layout_data = layout_cpu.data_ptr<int32_t>(); int32_t* layout_data = layout_cpu.mutable_data_ptr<int32_t>();
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
std::memcpy(layout_data + i * layout_width, // dst (int32*) std::memcpy(layout_data + i * layout_width, // dst (int32*)
&layout_B_reordered, // src (LayoutType*) &layout_B_reordered, // src (LayoutType*)
sizeof(LayoutType)); // number of bytes sizeof(LayoutType)); // number of bytes
} }
torch::Tensor packed_layout = torch::stable::Tensor packed_layout =
layout_cpu.to(b_tensors.device(), /*non_blocking=*/false); torch::stable::to(layout_cpu, b_tensors.device(),
/*non_blocking=*/false);
return {b_tensors_packed, packed_layout}; return {b_tensors_packed, packed_layout};
} }
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_moe_mm", &mm); m.impl("cutlass_w4a8_moe_mm", TORCH_BOX(&mm));
m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b); m.impl("cutlass_encode_and_reorder_int4b_grouped",
TORCH_BOX(&encode_and_reorder_int4b));
} }
} // namespace vllm::cutlass_w4a8_moe } // namespace vllm::cutlass_w4a8_moe
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -3,14 +3,12 @@
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu // https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
// //
#include <ATen/cuda/CUDAContext.h> #include <torch/csrc/stable/library.h>
#include <c10/cuda/CUDAGuard.h> #include <torch/csrc/stable/tensor.h>
#include <torch/all.h> #include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp" #include "cutlass_extensions/torch_utils.hpp"
#include "w4a8_utils.cuh" #include "w4a8_utils.cuh"
#include "core/registration.h"
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include <limits> #include <limits>
@@ -161,31 +159,31 @@ struct W4A8GemmKernel {
using StrideD = typename GemmKernelShuffled::StrideD; using StrideD = typename GemmKernelShuffled::StrideD;
using StrideS = typename CollectiveMainloopShuffled::StrideScale; using StrideS = typename CollectiveMainloopShuffled::StrideScale;
static torch::Tensor mm(torch::Tensor const& A, static torch::stable::Tensor mm(
torch::Tensor const& B, // already packed torch::stable::Tensor const& A,
torch::Tensor const& group_scales, // already packed torch::stable::Tensor const& B, // already packed
int64_t group_size, torch::stable::Tensor const& group_scales, // already packed
torch::Tensor const& channel_scales, int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::Tensor const& token_scales, torch::stable::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type) { std::optional<torch::headeronly::ScalarType> const& maybe_out_type) {
// TODO: param validation // TODO: param validation
int m = A.size(0); int m = A.size(0);
int k = A.size(1); int k = A.size(1);
int n = B.size(1); int n = B.size(1);
// safely cast group_size to int // safely cast group_size to int
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(), STD_TORCH_CHECK(
"group_size out of supported range for int: ", group_size); group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
int const group_size_int = static_cast<int>(group_size); int const group_size_int = static_cast<int>(group_size);
// Allocate output // Allocate output
const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
auto device = A.device(); auto device = A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index()); auto stream = get_current_cuda_stream(device.index());
torch::Tensor D = torch::stable::Tensor D = torch::stable::empty(
torch::empty({m, n}, torch::TensorOptions() {m, n}, equivalent_scalar_type_v<ElementD>, std::nullopt, device);
.dtype(equivalent_scalar_type_v<ElementD>)
.device(device));
// prepare arg pointers // prepare arg pointers
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr()); auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr()); auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
@@ -237,9 +235,9 @@ struct W4A8GemmKernel {
// Workspace // Workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments); size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::Tensor workspace = torch::stable::Tensor workspace = torch::stable::empty(
torch::empty(workspace_size, workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
torch::TensorOptions().dtype(torch::kU8).device(device)); device);
// Run GEMM // Run GEMM
GemmShuffled gemm; GemmShuffled gemm;
@@ -269,14 +267,14 @@ using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>; using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>; using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
torch::Tensor mm_dispatch(torch::Tensor const& A, torch::stable::Tensor mm_dispatch(
torch::Tensor const& B, // already packed torch::stable::Tensor const& A,
torch::Tensor const& group_scales, // already packed torch::stable::Tensor const& B, // already packed
int64_t group_size, torch::stable::Tensor const& group_scales, // already packed
torch::Tensor const& channel_scales, int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::Tensor const& token_scales, torch::stable::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type, std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
const std::string& schedule) { const std::string& schedule) {
if (schedule == "256x128_1x1x1") { if (schedule == "256x128_1x1x1") {
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size, return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
channel_scales, token_scales, channel_scales, token_scales,
@@ -318,17 +316,18 @@ torch::Tensor mm_dispatch(torch::Tensor const& A,
channel_scales, token_scales, channel_scales, token_scales,
maybe_out_type); maybe_out_type);
} }
TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule); STD_TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
return {}; return {};
} }
torch::Tensor mm(torch::Tensor const& A, torch::stable::Tensor mm(
torch::Tensor const& B, // already packed torch::stable::Tensor const& A,
torch::Tensor const& group_scales, // already packed torch::stable::Tensor const& B, // already packed
int64_t group_size, torch::Tensor const& channel_scales, torch::stable::Tensor const& group_scales, // already packed
torch::Tensor const& token_scales, int64_t group_size, torch::stable::Tensor const& channel_scales,
std::optional<at::ScalarType> const& maybe_out_type, torch::stable::Tensor const& token_scales,
std::optional<std::string> maybe_schedule) { std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
std::optional<std::string> maybe_schedule) {
// requested a specific schedule // requested a specific schedule
if (maybe_schedule) { if (maybe_schedule) {
return mm_dispatch(A, B, group_scales, group_size, channel_scales, return mm_dispatch(A, B, group_scales, group_size, channel_scales,
@@ -378,14 +377,15 @@ torch::Tensor mm(torch::Tensor const& A,
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Pre-processing utils // Pre-processing utils
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { torch::stable::Tensor pack_scale_fp8(torch::stable::Tensor const& scales) {
TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn); STD_TORCH_CHECK(scales.scalar_type() ==
TORCH_CHECK(scales.is_contiguous()); torch::headeronly::ScalarType::Float8_e4m3fn);
TORCH_CHECK(scales.is_cuda()); STD_TORCH_CHECK(scales.is_contiguous());
STD_TORCH_CHECK(scales.is_cuda());
auto packed_scales = torch::empty( auto packed_scales =
{scales.numel() * ScalePackSize}, torch::stable::empty({scales.numel() * ScalePackSize},
torch::TensorOptions().dtype(scales.dtype()).device(scales.device())); scales.scalar_type(), std::nullopt, scales.device());
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr()); auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
auto packed_scales_ptr = auto packed_scales_ptr =
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>( static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
@@ -396,15 +396,16 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
return packed_scales; return packed_scales;
} }
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { torch::stable::Tensor encode_and_reorder_int4b(torch::stable::Tensor const& B) {
TORCH_CHECK(B.dtype() == torch::kInt32); STD_TORCH_CHECK(B.scalar_type() == torch::headeronly::ScalarType::Int);
TORCH_CHECK(B.dim() == 2); STD_TORCH_CHECK(B.dim() == 2);
torch::Tensor B_packed = torch::empty_like(B); torch::stable::Tensor B_packed = torch::stable::empty_like(B);
int k = B.size(0) * PackFactor; // logical k int k = B.size(0) * PackFactor; // logical k
int n = B.size(1); int n = B.size(1);
TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks"); STD_TORCH_CHECK((n * k) % 32 == 0,
"need multiples of 32 int4s for 16B chunks");
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr()); auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr()); auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
@@ -415,16 +416,17 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr, bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
n * k); n * k);
TORCH_CHECK(ok, "unified_encode_int4b failed"); STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
return B_packed; return B_packed;
} }
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_mm", &mm); m.impl("cutlass_w4a8_mm", TORCH_BOX(&mm));
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8); m.impl("cutlass_pack_scale_fp8", TORCH_BOX(&pack_scale_fp8));
m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b); m.impl("cutlass_encode_and_reorder_int4b",
TORCH_BOX(&encode_and_reorder_int4b));
} }
} // namespace vllm::cutlass_w4a8 } // namespace vllm::cutlass_w4a8

View File

@@ -14,16 +14,15 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h" #include "cuda_utils.h"
#include "launch_bounds_utils.h" #include "launch_bounds_utils.h"
@@ -118,17 +117,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
} // namespace vllm } // namespace vllm
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] void silu_and_mul_nvfp4_quant_sm1xxa(
torch::Tensor& output_sf, torch::stable::Tensor& output, // [..., d]
torch::Tensor& input, // [..., 2 * d] torch::stable::Tensor& output_sf,
torch::Tensor& input_sf) { torch::stable::Tensor& input, // [..., 2 * d]
torch::stable::Tensor& input_sf) {
int32_t m = input.size(0); int32_t m = input.size(0);
int32_t n = input.size(1) / 2; int32_t n = input.size(1) / 2;
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || STD_TORCH_CHECK(
input.scalar_type() == at::ScalarType::BFloat16, input.scalar_type() == torch::headeronly::ScalarType::Half ||
"Unsupported input data type for quantize_to_fp4."); input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount = int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1); get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
@@ -136,8 +137,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr()); auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr()); auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr()); auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const torch::stable::accelerator::DeviceGuard device_guard(
auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM = int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x)); vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
@@ -149,7 +151,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y); dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES( VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type; using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr()); auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());

View File

@@ -14,14 +14,12 @@
* limitations under the License. * limitations under the License.
*/ */
#include "core/registration.h" #include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include <torch/all.h>
#include <cutlass/arch/arch.h> #include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cutlass_extensions/common.hpp" #include "cutlass_extensions/common.hpp"
#include "cute/tensor.hpp" #include "cute/tensor.hpp"
@@ -122,7 +120,7 @@ __global__ void __get_group_gemm_starts(
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \ #define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \ TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \ LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \ __get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \ LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \ <<<1, num_experts, 0, stream>>>( \
@@ -150,50 +148,64 @@ __global__ void __get_group_gemm_starts(
} }
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig> template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts( void run_get_group_gemm_starts(const torch::stable::Tensor& a_starts,
const torch::Tensor& a_starts, const torch::Tensor& b_starts, const torch::stable::Tensor& b_starts,
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts, const torch::stable::Tensor& out_starts,
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts, const torch::stable::Tensor& a_scales_starts,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, const torch::stable::Tensor& b_scales_starts,
const torch::Tensor& a_strides, const torch::Tensor& b_strides, const torch::stable::Tensor& alpha_starts,
const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val, const torch::stable::Tensor& layout_sfa,
int64_t c_stride_val, const torch::stable::Tensor& layout_sfb,
/*these are used for their base addresses*/ const torch::stable::Tensor& a_strides,
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, const torch::stable::Tensor& b_strides,
torch::Tensor const& out_tensors, torch::Tensor const& a_scales, const torch::stable::Tensor& c_strides,
torch::Tensor const& b_scales, torch::Tensor const& alphas, int64_t a_stride_val, int64_t b_stride_val,
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets, int64_t c_stride_val,
torch::Tensor const& problem_sizes, int M, int N, int K) { /*these are used for their base addresses*/
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& alphas,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& sf_offsets,
torch::stable::Tensor const& problem_sizes,
int M, int N, int K) {
int num_experts = (int)expert_offsets.size(0); int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); auto stream = get_current_cuda_stream(a_tensors.get_device_index());
TORCH_CHECK(out_tensors.size(1) == N, STD_TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape"); "Output tensor shape doesn't match expected shape");
TORCH_CHECK(K / 2 == b_tensors.size(2), STD_TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing" "b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match"); " dimension must match");
if (false) { if (false) {
} }
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig) // ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE( __CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16, cutlass::float_e2m1_t, cutlass::float_ue4m3_t,
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig) torch::headeronly::ScalarType::BFloat16, cutlass::bfloat16_t, LayoutSFA,
LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t, __CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
cutlass::float_ue4m3_t, torch::kFloat16, cutlass::float_ue4m3_t,
half, LayoutSFA, LayoutSFB, ScaleConfig) torch::headeronly::ScalarType::Half, half,
LayoutSFA, LayoutSFB, ScaleConfig)
else { else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
} }
} }
template <typename OutType> template <typename OutType>
void run_fp4_blockwise_scaled_group_mm_sm100( void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes, const torch::stable::Tensor& b_blockscales,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, const torch::stable::Tensor& alphas,
int N, int K) { const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
using ProblemShape = using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>; cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t; using ElementType = cutlass::float_e2m1_t;
@@ -272,20 +284,40 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0)); int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int); torch::stable::Tensor a_ptrs =
torch::Tensor b_ptrs = torch::empty(num_experts, options_int); torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
torch::Tensor out_ptrs = torch::empty(num_experts, options_int); std::nullopt, a.device());
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); torch::stable::Tensor b_ptrs =
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); std::nullopt, a.device());
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); torch::stable::Tensor out_ptrs =
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
torch::Tensor a_strides1 = torch::empty(num_experts, options_int); std::nullopt, a.device());
torch::Tensor b_strides1 = torch::empty(num_experts, options_int); torch::stable::Tensor a_scales_ptrs =
torch::Tensor c_strides1 = torch::empty(num_experts, options_int); torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>( run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -308,7 +340,7 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions; typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler; typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM; scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device(); hw_info.device_id = a.get_device_index();
static std::unordered_map<int, int> cached_sm_counts; static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] = cached_sm_counts[hw_info.device_id] =
@@ -350,32 +382,35 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
scheduler}; scheduler};
size_t workspace_size = Gemm::get_workspace_size(args); size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options = auto workspace =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
auto workspace = torch::empty(workspace_size, workspace_options); std::nullopt, a.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
auto can_implement_status = gemm_op.can_implement(args); auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, STD_TORCH_CHECK(
"Failed to implement GEMM: status=", (int)can_implement_status); can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM // Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr()); auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status, "Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts, " workspace_size=", workspace_size,
" M=", M, " N=", N, " K=", K); " num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream); status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
} }
void run_fp4_blockwise_scaled_group_mm_sm120( void run_fp4_blockwise_scaled_group_mm_sm120(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes, const torch::stable::Tensor& b_blockscales,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, const torch::stable::Tensor& alphas,
int N, int K) { const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
using ProblemShape = using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>; cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t; using ElementType = cutlass::float_e2m1_t;
@@ -446,20 +481,40 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0)); int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int); torch::stable::Tensor a_ptrs =
torch::Tensor b_ptrs = torch::empty(num_experts, options_int); torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
torch::Tensor out_ptrs = torch::empty(num_experts, options_int); std::nullopt, a.device());
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); torch::stable::Tensor b_ptrs =
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); std::nullopt, a.device());
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); torch::stable::Tensor out_ptrs =
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
torch::Tensor a_strides1 = torch::empty(num_experts, options_int); std::nullopt, a.device());
torch::Tensor b_strides1 = torch::empty(num_experts, options_int); torch::stable::Tensor a_scales_ptrs =
torch::Tensor c_strides1 = torch::empty(num_experts, options_int); torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>( run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -480,7 +535,7 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler; typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM; scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device(); hw_info.device_id = a.get_device_index();
static std::unordered_map<int, int> cached_sm_counts; static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] = cached_sm_counts[hw_info.device_id] =
@@ -523,33 +578,36 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
scheduler}; scheduler};
size_t workspace_size = Gemm::get_workspace_size(args); size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options = auto workspace =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
auto workspace = torch::empty(workspace_size, workspace_options); std::nullopt, a.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
auto can_implement_status = gemm_op.can_implement(args); auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, STD_TORCH_CHECK(
"Failed to implement GEMM: status=", (int)can_implement_status); can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM // Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr()); auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status, "Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts, " workspace_size=", workspace_size,
" M=", M, " N=", N, " K=", K); " num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream); status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
} }
template <typename OutType> template <typename OutType>
void run_fp4_blockwise_scaled_group_mm( void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes, const torch::stable::Tensor& b_blockscales,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, const torch::stable::Tensor& alphas,
int N, int K) { const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
if (version_num >= 120 && version_num < 130) { if (version_num >= 120 && version_num < 130) {
@@ -567,7 +625,7 @@ void run_fp4_blockwise_scaled_group_mm(
return; return;
} }
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED( STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, false,
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ", "No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 100 or 120"); version_num, ". Required capability: 100 or 120");
@@ -575,26 +633,31 @@ void run_fp4_blockwise_scaled_group_mm(
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \ #if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120) (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
#endif #endif
#define CHECK_TYPE(x, st, m) \ #define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \ #define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.") STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \ #define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.") STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \ #define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \ CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \ CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m) CHECK_TYPE(x, st, m)
void cutlass_fp4_group_mm( void cutlass_fp4_group_mm(torch::stable::Tensor& output,
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, const torch::stable::Tensor& a,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, const torch::stable::Tensor& b,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes, const torch::stable::Tensor& a_blockscale,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) { const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets) {
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \ #if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120) (defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
// Input validation // Input validation
@@ -602,30 +665,34 @@ void cutlass_fp4_group_mm(
CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale"); CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales"); CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas"); CHECK_INPUT(alphas, torch::headeronly::ScalarType::Float, "alphas");
TORCH_CHECK(a_blockscale.dim() == 2, STD_TORCH_CHECK(
"expected a_blockscale to be of shape [num_experts, rounded_m," a_blockscale.dim() == 2,
" k // group_size], observed rank: ", "expected a_blockscale to be of shape [num_experts, rounded_m,"
a_blockscale.dim()) " k // group_size], observed rank: ",
TORCH_CHECK(b_blockscales.dim() == 3, a_blockscale.dim())
"expected b_blockscale to be of shape: " STD_TORCH_CHECK(b_blockscales.dim() == 3,
" [num_experts, n, k // group_size], observed rank: ", "expected b_blockscale to be of shape: "
b_blockscales.dim()) " [num_experts, n, k // group_size], observed rank: ",
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor"); b_blockscales.dim())
TORCH_CHECK(problem_sizes.size(1) == 3, STD_TORCH_CHECK(problem_sizes.dim() == 2,
"problem_sizes must have the shape (num_experts, 3)"); "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), STD_TORCH_CHECK(problem_sizes.size(1) == 3,
"Number of experts in problem_sizes must match expert_offsets"); "problem_sizes must have the shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, STD_TORCH_CHECK(
"problem_sizes must be int32."); problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
STD_TORCH_CHECK(
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32.");
int M = static_cast<int>(a.size(0)); int M = static_cast<int>(a.size(0));
int N = static_cast<int>(b.size(1)); int N = static_cast<int>(b.size(1));
int E = static_cast<int>(b.size(0)); int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2)); int K = static_cast<int>(2 * b.size(2));
if (output.scalar_type() == torch::kBFloat16) { if (output.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>( run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K); expert_offsets, sf_offsets, M, N, K);
@@ -633,7 +700,7 @@ void cutlass_fp4_group_mm(
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
int32_t version_num = get_sm_version_num(); int32_t version_num = get_sm_version_num();
if (version_num >= 120 && version_num < 130) { if (version_num >= 120 && version_num < 130) {
TORCH_CHECK_NOT_IMPLEMENTED( STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ", false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
output.scalar_type()); output.scalar_type());
} }
@@ -643,7 +710,7 @@ void cutlass_fp4_group_mm(
expert_offsets, sf_offsets, M, N, K); expert_offsets, sf_offsets, M, N, K);
} }
#else #else
TORCH_CHECK_NOT_IMPLEMENTED( STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must " "No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 " "be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
@@ -651,6 +718,6 @@ void cutlass_fp4_group_mm(
#endif #endif
} }
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm); m.impl("cutlass_fp4_group_mm", TORCH_BOX(&cutlass_fp4_group_mm));
} }

View File

@@ -14,16 +14,15 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h" #include "cuda_utils.h"
#include "nvfp4_utils.cuh" #include "nvfp4_utils.cuh"
@@ -327,25 +326,28 @@ void quant_impl(void* output, void* output_scale, void* input,
} // namespace vllm } // namespace vllm
/*Quantization entry for fp4 experts quantization*/ /*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") #define CHECK_TH_CUDA(x, m) \
STD_TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \ #define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") STD_TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \ #define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \ CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); CHECK_CONTIGUOUS(x, m);
constexpr auto HALF = at::ScalarType::Half; constexpr auto HALF = torch::headeronly::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16; constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float; constexpr auto FLOAT = torch::headeronly::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int; constexpr auto INT = torch::headeronly::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte; constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;
// Common validation for fp4 experts quantization entry points. // Common validation for fp4 experts quantization entry points.
static void validate_fp4_experts_quant_inputs( static void validate_fp4_experts_quant_inputs(
torch::Tensor const& output, torch::Tensor const& output_scale, torch::stable::Tensor const& output,
torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::stable::Tensor const& output_scale,
torch::Tensor const& input_offset_by_experts, torch::stable::Tensor const& input,
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk, torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
int64_t k) { int64_t k) {
CHECK_INPUT(output, "output"); CHECK_INPUT(output, "output");
CHECK_INPUT(output_scale, "output_scale"); CHECK_INPUT(output_scale, "output_scale");
@@ -354,41 +356,42 @@ static void validate_fp4_experts_quant_inputs(
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts"); CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts"); CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
TORCH_CHECK(output.dim() == 2); STD_TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2); STD_TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2); STD_TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1); STD_TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1); STD_TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); STD_TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8) // output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32) // output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK(output.scalar_type() == UINT8); STD_TORCH_CHECK(output.scalar_type() == UINT8);
TORCH_CHECK(output_scale.scalar_type() == INT); STD_TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16; const int BLOCK_SIZE = 16;
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); STD_TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0); auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); STD_TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output.size(0) == m_topk); STD_TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2); STD_TORCH_CHECK(output.size(1) == k / 2);
int scales_k = k / BLOCK_SIZE; int scales_k = k / BLOCK_SIZE;
// 4 means the swizzle requirement by nvidia nvfp4. // 4 means the swizzle requirement by nvidia nvfp4.
int padded_k = (scales_k + (4 - 1)) / 4 * 4; int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32 // 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k); STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
} }
void scaled_fp4_experts_quant_sm1xxa( void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale, torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::stable::Tensor const& input,
torch::Tensor const& input_offset_by_experts, torch::stable::Tensor const& input_global_scale,
torch::Tensor const& output_scale_offset_by_experts) { torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0); auto m_topk = input.size(0);
auto k = input.size(1); auto k = input.size(1);
@@ -397,11 +400,11 @@ void scaled_fp4_experts_quant_sm1xxa(
output_scale_offset_by_experts, m_topk, k); output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0); auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const torch::stable::accelerator::DeviceGuard device_guard(
const cudaStream_t stream = input.get_device_index());
at::cuda::getCurrentCUDAStream(input.get_device()); const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
VLLM_DISPATCH_HALF_TYPES( VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type; using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>( vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
@@ -413,14 +416,15 @@ void scaled_fp4_experts_quant_sm1xxa(
} }
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale, torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::stable::Tensor const& input,
torch::Tensor const& input_offset_by_experts, torch::stable::Tensor const& input_global_scale,
torch::Tensor const& output_scale_offset_by_experts) { torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0); auto m_topk = input.size(0);
// Input has gate || up layout, so k = input.size(1) / 2 // Input has gate || up layout, so k = input.size(1) / 2
auto k_times_2 = input.size(1); auto k_times_2 = input.size(1);
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)"); STD_TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
auto k = k_times_2 / 2; auto k = k_times_2 / 2;
validate_fp4_experts_quant_inputs(output, output_scale, input, validate_fp4_experts_quant_inputs(output, output_scale, input,
@@ -428,11 +432,11 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output_scale_offset_by_experts, m_topk, k); output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0); auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const torch::stable::accelerator::DeviceGuard device_guard(
const cudaStream_t stream = input.get_device_index());
at::cuda::getCurrentCUDAStream(input.get_device()); const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
VLLM_DISPATCH_HALF_TYPES( VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] { input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type; using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>( vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(

View File

@@ -0,0 +1,172 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
// Forward declarations of the arch-specific kernels (compiled separately,
// presumably from the nvfp4 quant kernel .cu files — only linked in when at
// least one NVFP4 SM build flag is set). All four declarations share the
// exact same preprocessor condition, so a single guard replaces the four
// identical repeated #if blocks.
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
                             torch::stable::Tensor const& input,
                             torch::stable::Tensor const& output_sf,
                             torch::stable::Tensor const& input_sf,
                             bool is_sf_swizzled_layout);

void scaled_fp4_experts_quant_sm1xxa(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_global_scale,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts);

void silu_and_mul_nvfp4_quant_sm1xxa(torch::stable::Tensor& output,
                                     torch::stable::Tensor& output_sf,
                                     torch::stable::Tensor& input,
                                     torch::stable::Tensor& input_sf);

void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_global_scale,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts);
#endif
// Returns true when the running device's SM version falls inside a range for
// which an nvfp4 quantization kernel was compiled in: [100, 120) when the
// SM100 build is enabled, [120, 130) when the SM120 build is enabled.
static bool nvfp4_quant_sm_supported() {
  const int32_t sm = get_sm_version_num();
  bool supported = false;
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
  supported = supported || (sm >= 100 && sm < 120);
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
  supported = supported || (sm >= 120 && sm < 130);
#endif
  return supported;
}
// Quantizes `input` to nvfp4 into the caller-provided `output` / `output_sf`
// tensors. Fails with a runtime check when the running SM has no compiled
// kernel, and with NOT_IMPLEMENTED when no nvfp4 kernel was built at all.
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
                          torch::stable::Tensor const& input_sf,
                          bool is_sf_swizzled_layout,
                          torch::stable::Tensor& output,
                          torch::stable::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
  // Kernels may only cover a subset of SMs; reject unsupported devices early.
  STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
                  "No compiled nvfp4 quantization kernel for SM ",
                  get_sm_version_num(),
                  ". Recompile with the appropriate CUDA arch.");
  scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
                          is_sf_swizzled_layout);
#else
  STD_TORCH_CHECK_NOT_IMPLEMENTED(false,
                                  "No compiled nvfp4 quantization kernel");
#endif
}
// Allocating wrapper around scaled_fp4_quant_out: sizes and creates the
// packed-fp4 output and its scale-factor tensor, runs the quantization, and
// returns both. Input is treated as an (m, n) matrix with m = numel / n.
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
    torch::stable::Tensor const& input, torch::stable::Tensor const& input_sf,
    bool is_sf_swizzled_layout) {
  const int64_t n = input.size(-1);
  const int64_t m = input.numel() / n;
  const auto device = input.device();

  // Two fp4 values are packed into one uint8, halving the last dimension.
  auto output = torch::stable::empty(
      {m, n / 2}, torch::headeronly::ScalarType::Byte, std::nullopt, device);

  // Pick the scale-factor layout first, then allocate once.
  int64_t sf_rows = 0;
  int64_t sf_cols = 0;
  auto sf_dtype = torch::headeronly::ScalarType::Byte;
  if (is_sf_swizzled_layout) {
    auto [swizzled_m, swizzled_n] = vllm::computeSwizzledSFShape(m, n);
    sf_rows = swizzled_m;
    sf_cols = swizzled_n;
    // Swizzled layout stores packed scale factors as int32 words.
    sf_dtype = torch::headeronly::ScalarType::Int;
  } else {
    sf_rows = m;
    sf_cols = n / CVT_FP4_SF_VEC_SIZE;
  }
  auto output_sf =
      torch::stable::empty({sf_rows, sf_cols}, sf_dtype, std::nullopt, device);

  scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
                       output_sf);
  return {output, output_sf};
}
// Per-expert nvfp4 quantization entry point: validates that a kernel exists
// for the running SM, then forwards all arguments to the arch-specific
// implementation. NOT_IMPLEMENTED when no nvfp4 kernel was built.
void scaled_fp4_experts_quant(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_global_scale,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
  STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
                  "No compiled nvfp4 experts quantization kernel for SM ",
                  get_sm_version_num(),
                  ". Recompile with the appropriate CUDA arch.");
  scaled_fp4_experts_quant_sm1xxa(output, output_scale, input,
                                  input_global_scale, input_offset_by_experts,
                                  output_scale_offset_by_experts);
#else
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No compiled nvfp4 experts quantization kernel");
#endif
}
// Fused silu-and-mul + nvfp4 quantization entry point: validates SM support
// and dispatches to the arch-specific fused kernel. NOT_IMPLEMENTED when no
// nvfp4 kernel was built.
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& output,
                              torch::stable::Tensor& output_sf,
                              torch::stable::Tensor& input,
                              torch::stable::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
  STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
                  "No compiled silu_and_mul nvfp4 quantization kernel for SM ",
                  get_sm_version_num(),
                  ". Recompile with the appropriate CUDA arch.");
  silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#else
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No compiled silu_and_mul nvfp4 quantization kernel");
#endif
}
// Fused silu-and-mul + per-expert nvfp4 quantization entry point: validates
// SM support and forwards to the arch-specific fused kernel.
// NOT_IMPLEMENTED when no nvfp4 kernel was built.
void silu_and_mul_scaled_fp4_experts_quant(
    torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
    torch::stable::Tensor const& input,
    torch::stable::Tensor const& input_global_scale,
    torch::stable::Tensor const& input_offset_by_experts,
    torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
  STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
                  "No compiled silu_and_mul nvfp4 experts quantization kernel "
                  "for SM ",
                  get_sm_version_num(),
                  ". Recompile with the appropriate CUDA arch.");
  silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
      output, output_scale, input, input_global_scale, input_offset_by_experts,
      output_scale_offset_by_experts);
#else
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
#endif
}

View File

@@ -14,16 +14,16 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "cuda_utils.h" #include "cuda_utils.h"
#include "launch_bounds_utils.h" #include "launch_bounds_utils.h"
@@ -173,18 +173,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
} // namespace vllm } // namespace vllm
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
torch::Tensor const& input, torch::stable::Tensor const& input,
torch::Tensor const& output_sf, torch::stable::Tensor const& output_sf,
torch::Tensor const& input_sf, torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout) { bool is_sf_swizzled_layout) {
int32_t m = input.size(0); int32_t m = input.size(0);
int32_t n = input.size(1); int32_t n = input.size(1);
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16."); STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half || STD_TORCH_CHECK(
input.scalar_type() == at::ScalarType::BFloat16, input.scalar_type() == torch::headeronly::ScalarType::Half ||
"Unsupported input data type for quantize_to_fp4."); input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount = int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1); get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
@@ -192,8 +193,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr()); auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr()); auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr()); auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const torch::stable::accelerator::DeviceGuard device_guard(
auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
@@ -213,15 +215,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y); dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { VLLM_STABLE_DISPATCH_HALF_TYPES(
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type; input.scalar_type(), "nvfp4_quant_kernel", [&] {
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr()); using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
// NOTE: We don't support e8m0 scales at this moment. auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>( vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, num_padded_cols, input_ptr, input_sf_ptr, m, n, num_padded_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr), reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out)); reinterpret_cast<uint32_t*>(sf_out));
}); });
} else { } else {
int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD; int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x)); int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
@@ -229,15 +231,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y); dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { VLLM_STABLE_DISPATCH_HALF_TYPES(
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type; input.scalar_type(), "nvfp4_quant_kernel", [&] {
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr()); using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
// NOTE: We don't support e8m0 scales at this moment. auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false> vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols, <<<grid, block, 0, stream>>>(
input_ptr, input_sf_ptr, m, n, sf_n_unpadded, num_packed_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr), reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out)); reinterpret_cast<uint32_t*>(sf_out));
}); });
} }
} }

View File

@@ -14,32 +14,39 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <c10/cuda/CUDAGuard.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp" #include "cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100 #if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
torch::Tensor const& B, torch::stable::Tensor const& A,
torch::Tensor const& A_sf, torch::stable::Tensor const& B,
torch::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha); torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif #endif
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120 #if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::Tensor const& B, torch::stable::Tensor const& A,
torch::Tensor const& A_sf, torch::stable::Tensor const& B,
torch::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha); torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif #endif
void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A, void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
const torch::Tensor& B, const torch::Tensor& A_sf, const torch::stable::Tensor& A,
const torch::Tensor& B_sf, const torch::stable::Tensor& B,
const torch::Tensor& alpha) { const torch::stable::Tensor& A_sf,
// Make sure were on As device. const torch::stable::Tensor& B_sf,
const c10::cuda::OptionalCUDAGuard device_guard(device_of(A)); const torch::stable::Tensor& alpha) {
// Make sure we're on A's device.
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const int32_t sm = get_sm_version_num(); const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 #if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
@@ -56,8 +63,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
} }
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm, STD_TORCH_CHECK_NOT_IMPLEMENTED(
". Recompile with CUDA >= 12.8 and CC >= 100."); false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
} }
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) { bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {

View File

@@ -14,10 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <ATen/cuda/CUDAContext.h> #include "libtorch_stable/torch_utils.h"
#include <c10/cuda/CUDAGuard.h>
#include "cutlass_extensions/common.hpp" #include "cutlass_extensions/common.hpp"
@@ -127,8 +126,9 @@ struct Fp4GemmSm100 {
template <typename Config> template <typename Config>
typename Config::Gemm::Arguments args_from_options( typename Config::Gemm::Arguments args_from_options(
at::Tensor& D, at::Tensor const& A, at::Tensor const& B, torch::stable::Tensor& D, torch::stable::Tensor const& A,
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha, torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha,
int64_t M, int64_t N, int64_t K) { int64_t M, int64_t N, int64_t K) {
using ElementA = typename Config::Gemm::ElementA; using ElementA = typename Config::Gemm::ElementA;
using ElementB = typename Config::Gemm::ElementB; using ElementB = typename Config::Gemm::ElementB;
@@ -174,19 +174,20 @@ typename Config::Gemm::Arguments args_from_options(
} }
template <typename Config> template <typename Config>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
at::Tensor const& A_sf, at::Tensor const& B_sf, torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k, torch::stable::Tensor const& B_sf,
cudaStream_t stream) { torch::stable::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
typename Config::Gemm gemm; typename Config::Gemm gemm;
auto arguments = auto arguments =
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k); args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
size_t workspace_size = Config::Gemm::get_workspace_size(arguments); size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
auto const workspace_options = auto workspace =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
auto workspace = torch::empty(workspace_size, workspace_options); std::nullopt, A.device());
CUTLASS_CHECK(gemm.can_implement(arguments)); CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -197,12 +198,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
// Dispatch function to select appropriate config based on M // Dispatch function to select appropriate config based on M
template <typename OutType> template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::Tensor const& B, torch::stable::Tensor const& A,
torch::Tensor const& A_sf, torch::stable::Tensor const& B,
torch::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha, int64_t m, int64_t n, torch::stable::Tensor const& B_sf,
int64_t k, cudaStream_t stream) { torch::stable::Tensor const& alpha, int64_t m,
int64_t n, int64_t k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m)); uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 16) { if (mp2 <= 16) {
@@ -222,61 +224,65 @@ void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
#else #else
template <typename OutType> template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::Tensor const& B, torch::stable::Tensor const& A,
torch::Tensor const& A_sf, torch::stable::Tensor const& B,
torch::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha, int64_t m, int64_t n, torch::stable::Tensor const& B_sf,
int64_t k, cudaStream_t stream) { torch::stable::Tensor const& alpha, int64_t m,
TORCH_CHECK(false, int64_t n, int64_t k, cudaStream_t stream) {
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " STD_TORCH_CHECK(false,
"a CUTLASS 3.8 source directory to enable support."); "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
} }
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) #endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#define CHECK_TYPE(x, st, m) \ #define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \ #define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor") STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \ #define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous") STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \ #define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \ CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \ CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m) CHECK_TYPE(x, st, m)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
torch::Tensor const& B, torch::stable::Tensor const& A,
torch::Tensor const& A_sf, torch::stable::Tensor const& B,
torch::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha) { torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha) {
CHECK_INPUT(A, FLOAT4_E2M1X2, "a"); CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b"); CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix"); STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix"); STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(A.sizes()[1] == B.sizes()[1], STD_TORCH_CHECK(A.size(1) == B.size(1),
"a and b shapes cannot be multiplied (", A.sizes()[0], "x", "a and b shapes cannot be multiplied (", A.size(0), "x",
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")"); A.size(1), " and ", B.size(0), "x", B.size(1), ")");
auto const m = A.sizes()[0]; auto const m = A.size(0);
auto const n = B.sizes()[0]; auto const n = B.size(0);
auto const k = A.sizes()[1] * 2; auto const k = A.size(1) * 2;
constexpr int alignment = 32; constexpr int alignment = 32;
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment, STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1], alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
"), k: ", k, "."); "), k: ", k, ".");
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment, STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ")."); alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128); int rounded_m = round_up(m, 128);
@@ -285,33 +291,34 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
// integer. // integer.
int rounded_k = round_up(k / 16, 4); int rounded_k = round_up(k / 16, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1], STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
"scale_a and scale_b shapes cannot be multiplied (", "scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0], A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
"x", B_sf.sizes()[1], ")"); B_sf.size(1), ")");
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m, "scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x", "x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
A_sf.sizes()[1], ")"); A_sf.size(1), ")");
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n, "scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x", "x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
B_sf.sizes()[1], ")"); B_sf.size(1), ")");
auto out_dtype = D.dtype(); auto out_dtype = D.scalar_type();
const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const torch::stable::accelerator::DeviceGuard device_guard(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); A.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
if (out_dtype == at::ScalarType::Half) { if (out_dtype == torch::headeronly::ScalarType::Half) {
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
k, stream); k, stream);
} else if (out_dtype == at::ScalarType::BFloat16) { } else if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
m, n, k, stream); m, n, k, stream);
} else { } else {
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype, STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (",
")"); out_dtype, ")");
} }
} }

View File

@@ -14,10 +14,9 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/all.h> #include <torch/csrc/stable/tensor.h>
#include <ATen/cuda/CUDAContext.h> #include "libtorch_stable/torch_utils.h"
#include <c10/cuda/CUDAGuard.h>
#include "cutlass_extensions/common.hpp" #include "cutlass_extensions/common.hpp"
@@ -34,19 +33,20 @@
using namespace cute; using namespace cute;
#define CHECK_TYPE(x, st, m) \ #define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \ #define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor") STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \ #define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous") STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \ #define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \ CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \ CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m) CHECK_TYPE(x, st, m)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
struct sm120_fp4_config_M256 { struct sm120_fp4_config_M256 {
using ClusterShape = Shape<_1, _1, _1>; using ClusterShape = Shape<_1, _1, _1>;
@@ -109,12 +109,13 @@ struct Fp4GemmSm120 {
}; };
template <typename Gemm> template <typename Gemm>
typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A, typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
at::Tensor const& B, torch::stable::Tensor const& A,
at::Tensor const& A_sf, torch::stable::Tensor const& B,
at::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha, int M, torch::stable::Tensor const& B_sf,
int N, int K) { torch::stable::Tensor const& alpha,
int M, int N, int K) {
using ElementA = typename Gemm::ElementA; using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB; using ElementB = typename Gemm::ElementB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
@@ -158,18 +159,19 @@ typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
} }
template <typename Gemm> template <typename Gemm>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
at::Tensor const& A_sf, at::Tensor const& B_sf, torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha, int M, int N, int K, torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int M, int N, int K,
cudaStream_t stream) { cudaStream_t stream) {
Gemm gemm; Gemm gemm;
auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K); auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
size_t workspace_size = Gemm::get_workspace_size(arguments); size_t workspace_size = Gemm::get_workspace_size(arguments);
auto const workspace_options = auto workspace =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
auto workspace = torch::empty(workspace_size, workspace_options); std::nullopt, A.device());
CUTLASS_CHECK(gemm.can_implement(arguments)); CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -178,12 +180,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
} }
void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
torch::Tensor const& B, torch::stable::Tensor const& A,
torch::Tensor const& A_sf, torch::stable::Tensor const& B,
torch::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha, int m, int n, torch::stable::Tensor const& B_sf,
int k, cudaStream_t stream) { torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m)); uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) { if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>( runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
@@ -194,12 +197,13 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
} }
} }
void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
torch::Tensor const& B, torch::stable::Tensor const& A,
torch::Tensor const& A_sf, torch::stable::Tensor const& B,
torch::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha, int m, int n, torch::stable::Tensor const& B_sf,
int k, cudaStream_t stream) { torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m)); uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) { if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>( runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
@@ -210,11 +214,12 @@ void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
} }
} }
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A, void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::Tensor const& B, torch::stable::Tensor const& A,
torch::Tensor const& A_sf, torch::stable::Tensor const& B,
torch::Tensor const& B_sf, torch::stable::Tensor const& A_sf,
torch::Tensor const& alpha) { torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha) {
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) #if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
CHECK_INPUT(A, FLOAT4_E2M1X2, "a"); CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b"); CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
@@ -222,24 +227,25 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a"); CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b"); CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha"); CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix"); STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix"); STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(A.sizes()[1] == B.sizes()[1], STD_TORCH_CHECK(A.size(1) == B.size(1),
"a and b shapes cannot be multiplied (", A.sizes()[0], "x", "a and b shapes cannot be multiplied (", A.size(0), "x",
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")"); A.size(1), " and ", B.size(0), "x", B.size(1), ")");
auto const m = A.sizes()[0]; auto const m = A.size(0);
auto const n = B.sizes()[0]; auto const n = B.size(0);
auto const k = A.sizes()[1] * 2; auto const k = A.size(1) * 2;
constexpr int alignment = 32; constexpr int alignment = 32;
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment, STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1], alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
"), k: ", k, "."); "), k: ", k, ".");
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment, STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ")."); alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; }; auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128); int rounded_m = round_up(m, 128);
@@ -248,38 +254,39 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
// integer. // integer.
int rounded_k = round_up(k / 16, 4); int rounded_k = round_up(k / 16, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix"); STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix"); STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1], STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
"scale_a and scale_b shapes cannot be multiplied (", "scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0], A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
"x", B_sf.sizes()[1], ")"); B_sf.size(1), ")");
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k, STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m, "scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x", "x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
A_sf.sizes()[1], ")"); A_sf.size(1), ")");
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k, STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n, "scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x", "x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
B_sf.sizes()[1], ")"); B_sf.size(1), ")");
auto out_dtype = D.dtype(); auto out_dtype = D.scalar_type();
const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); const torch::stable::accelerator::DeviceGuard device_guard(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); A.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
if (out_dtype == at::ScalarType::BFloat16) { if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream); stream);
} else if (out_dtype == at::ScalarType::Half) { } else if (out_dtype == torch::headeronly::ScalarType::Half) {
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream); stream);
} else { } else {
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (", STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
out_dtype, ")"); out_dtype, ")");
} }
#else #else
TORCH_CHECK(false, STD_TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support."); "a CUTLASS 3.8 source directory to enable support.");
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED) #endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
} }

View File

@@ -20,7 +20,7 @@
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <utility> #include <utility>
#include "../../cuda_vec_utils.cuh" #include "cuda_vec_utils.cuh"
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \ #if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12090 CUDA_VERSION >= 12090

View File

@@ -103,6 +103,102 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
ops.def( ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> " "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
"bool"); "bool");
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
// Out variant
// TODO: Add out_variant tag once PyTorch supports it (added in 2.11)
// This registration is now migrated to stable ABI
// at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not
// yet in torch/headeronly), the tag should be applied from Python
// via torch.library.Library.define(..., tags=(torch.Tag.out_variant,))
// with the .impl remaining in C++.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 quantization.
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()");
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
#endif #endif
} }
@@ -128,6 +224,18 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets)); TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
ops.impl("get_cutlass_batched_moe_mm_data", ops.impl("get_cutlass_batched_moe_mm_data",
TORCH_BOX(&get_cutlass_batched_moe_mm_data)); TORCH_BOX(&get_cutlass_batched_moe_mm_data));
// FP4/NVFP4 ops
ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm));
ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func));
ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out));
ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant));
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
#endif #endif
} }
@@ -143,6 +251,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
TORCH_BOX(&cutlass_group_gemm_supported)); TORCH_BOX(&cutlass_group_gemm_supported));
ops.impl("cutlass_scaled_mm_supports_block_fp8", ops.impl("cutlass_scaled_mm_supports_block_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8)); TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
ops.impl("cutlass_scaled_mm_supports_fp4",
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
#endif #endif
} }

View File

@@ -2,6 +2,7 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h> #include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h> #include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h> #include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h> #include <torch/headeronly/util/shim_utils.h>

View File

@@ -152,12 +152,6 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale); torch::Tensor& scale);
#ifndef USE_ROCM
void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& output_block_scale,
torch::Tensor& input,
torch::Tensor& input_global_scale);
#endif
void persistent_masked_m_silu_mul_quant( void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H) const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E) const at::Tensor& counts, // (E)
@@ -225,44 +219,6 @@ torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
int64_t ggml_moe_get_block_size(int64_t type); int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B, torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_scale,
bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_scale,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale, torch::Tensor const& scale,
std::optional<torch::Tensor> const& azp); std::optional<torch::Tensor> const& azp);

View File

@@ -1,163 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
torch::Tensor& output_sf,
torch::Tensor& input,
torch::Tensor& input_sf);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
static bool nvfp4_quant_sm_supported() {
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) return true;
#endif
return false;
}
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
}
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int64_t n = input.size(-1);
int64_t m = input.numel() / n;
auto device = input.device();
// Two fp4 values packed into a uint8
auto output = torch::empty(
{m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));
torch::Tensor output_sf;
if (is_sf_swizzled_layout) {
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
output_sf = torch::empty(
{sf_m, sf_n},
torch::TensorOptions().device(device).dtype(torch::kInt32));
} else {
output_sf = torch::empty(
{m, n / CVT_FP4_SF_VEC_SIZE},
torch::TensorOptions().device(device).dtype(torch::kUInt8));
}
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
output_sf);
return {output, output_sf};
}
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 experts quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 experts quantization kernel");
}
void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
torch::Tensor& input, torch::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 experts quantization kernel "
"for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
}

View File

@@ -109,13 +109,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
#ifndef USE_ROCM
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()");
ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant);
#endif
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
@@ -332,47 +325,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? qzeros_or_none, bool inplace) -> Tensor"); "Tensor? qzeros_or_none, bool inplace) -> Tensor");
// conditionally compiled so impl registrations are in source file // conditionally compiled so impl registrations are in source file
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// conditionally compiled so impl registration is in source file
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
// conditionally compiled so impl registration is in source file
#endif #endif
// Dequantization for GGML. // Dequantization for GGML.
@@ -409,20 +361,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
#ifndef USE_ROCM #ifndef USE_ROCM
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// conditionally compiled so impl registration is in source file
// Expert-specialization mxfp8 blockscaled grouped quantization (SM100+). // Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
ops.def( ops.def(
"mxfp8_experts_quant(" "mxfp8_experts_quant("
@@ -455,44 +393,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"-> int"); "-> int");
// conditionally compiled so impl in source file // conditionally compiled so impl in source file
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func);
// Out variant
// TODO: Add {at::Tag::out_variant} tag and update all call sites
// to use the functional variant once vLLM upgrades PyTorch.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out);
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA,
&silu_and_mul_scaled_fp4_experts_quant);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4);
#endif #endif
// Quantized GEMM for GPTQ. // Quantized GEMM for GPTQ.