[4/n] Migrate FP4/W4A8 CUTLASS kernels to torch stable ABI (#37503)

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
mikaylagawarecki
2026-03-31 13:21:13 -04:00
committed by GitHub
parent 0dd25a44ea
commit 7c080dd3c5
27 changed files with 1205 additions and 1016 deletions

View File

@@ -340,8 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_EXT_SRC
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp")
set_gencode_flags_for_srcs(
@@ -489,59 +487,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# CUTLASS MLA Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
@@ -681,34 +626,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
# Hadacore kernels
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
@@ -760,7 +677,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(VLLM_STABLE_EXT_SRC
"csrc/libtorch_stable/torch_bindings.cpp"
"csrc/cutlass_extensions/common.cpp"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu")
"csrc/cuda_utils_kernels.cu"
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_STABLE_EXT_SRC
@@ -978,6 +898,96 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
#
# FP4/NVFP4 kernels (moved from _C to _C_stable_libtorch)
#
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
# CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
# FP4 Archs and flags
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building NVFP4 as no compatible archs were found.")
# clear FP4_ARCHS
set(FP4_ARCHS)
endif()
#
# W4A8 kernels (moved from _C to _C_stable_libtorch)
#
# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
message(STATUS "Enabling C_stable extension.")
define_extension_target(
_C_stable_libtorch

View File

@@ -3,8 +3,8 @@
#pragma once
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <cassert>
#ifdef USE_ROCM

View File

@@ -1,7 +1,6 @@
#pragma once
#include <cute/tensor.hpp>
#include <torch/all.h>
namespace cute {
////////////////////////////////////////////////////////////////////

View File

@@ -189,9 +189,9 @@ struct Sm90RowOrScalarBroadcastArray {
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -211,8 +211,8 @@ struct Sm90RowOrScalarBroadcastArray {
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
@@ -241,9 +241,9 @@ struct Sm90RowOrScalarBroadcastArray {
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -251,16 +251,16 @@ struct Sm90RowOrScalarBroadcastArray {
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
@@ -389,7 +389,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
@@ -409,7 +409,7 @@ struct Sm90ColOrScalarBroadcastArray {
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
@@ -431,16 +431,16 @@ struct Sm90ColOrScalarBroadcastArray {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor cCol = make_identity_tensor(mCol.shape());
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(

View File

@@ -186,9 +186,9 @@ struct Sm90RowOrScalarBroadcast {
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
@@ -208,8 +208,8 @@ struct Sm90RowOrScalarBroadcast {
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
@@ -238,9 +238,9 @@ struct Sm90RowOrScalarBroadcast {
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
@@ -248,16 +248,16 @@ struct Sm90RowOrScalarBroadcast {
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
@@ -382,7 +382,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
@@ -402,7 +402,7 @@ struct Sm90ColOrScalarBroadcast {
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
@@ -422,16 +422,16 @@ struct Sm90ColOrScalarBroadcast {
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cute::Tensor cCol = make_identity_tensor(mCol.shape());
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(

View File

@@ -1,6 +1,21 @@
#pragma once
#include <torch/all.h>
// This header is shared between _C (unstable ABI, used by machete) and
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
// is defined only for the stable target, so we switch includes and types
// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
using TorchTensor = torch::stable::Tensor;
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
#else
#include <torch/all.h>
using TorchTensor = torch::Tensor;
#define TORCH_UTILS_CHECK TORCH_CHECK
#endif
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
@@ -55,35 +70,35 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(torch::Tensor const& tensor,
static inline auto make_cute_layout(TorchTensor const& tensor,
std::string_view name = "tensor") {
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(
Stride{}, [&](auto const& stride_ele, auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele,
auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.stride(idx);
}
}
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
return tensor.stride(idx);
}
});
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
if (idx < tensor.dim())
@@ -97,7 +112,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
template <typename Stride>
static inline auto maybe_make_cute_layout(
std::optional<torch::Tensor> const& tensor,
std::optional<TorchTensor> const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor));
@@ -121,12 +136,12 @@ template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <>
struct equivalent_cutlass_type<c10::Half> {
struct equivalent_cutlass_type<torch::headeronly::Half> {
using type = cutlass::half_t;
};
template <>
struct equivalent_cutlass_type<c10::BFloat16> {
struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
using type = cutlass::bfloat16_t;
};
@@ -134,8 +149,8 @@ struct equivalent_cutlass_type<c10::BFloat16> {
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
// Return a `torch::headeronly::CppTypeToScalarType<T>` compatible type, i.e.
// get the C++ type equivalent to T, e.g.: `cutlass::half_t -> Half`
template <typename T>
struct equivalent_scalar_type {
using type = T;
@@ -146,15 +161,15 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <>
struct equivalent_scalar_type<cutlass::half_t> {
using type = c10::Half;
using type = torch::headeronly::Half;
};
template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
using type = c10::BFloat16;
using type = torch::headeronly::BFloat16;
};
// get equivalent c10::ScalarType tag from compile time type
// get equivalent torch::headeronly::ScalarType tag from compile time type
template <typename T>
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v =
torch::headeronly::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;

View File

@@ -49,6 +49,15 @@
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
// Half types dispatch (Half + BFloat16)
#define VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(...) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_STABLE_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))
// Boolean dispatch
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
if (expr) { \

View File

@@ -84,4 +84,54 @@ void get_cutlass_batched_moe_mm_data(
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
// FP4/NVFP4 ops
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets);
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale, bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_scale,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
torch::stable::Tensor& output_block_scale,
torch::stable::Tensor& input,
torch::stable::Tensor& input_global_scale);
#endif

View File

@@ -2,10 +2,9 @@
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
@@ -41,7 +40,7 @@ __global__ void get_group_gemm_starts(
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
cutlass::Array<cutlass::float_e4m3_t, 8>> \
<<<1, num_experts, 0, stream>>>( \
@@ -66,23 +65,34 @@ __global__ void get_group_gemm_starts(
namespace {
void run_get_group_gemm_starts(
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
torch::Tensor const& a_scales, torch::Tensor const& b_scales,
torch::Tensor const& b_group_scales, const int64_t b_group_size) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_group_scales.dtype() ==
torch::kFloat8_e4m3fn); // the underlying torch type is e4m3
TORCH_CHECK(out_tensors.dtype() ==
torch::kBFloat16); // only support bf16 for now
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
torch::stable::Tensor& b_group_scales_ptrs,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& b_group_scales, const int64_t b_group_size) {
STD_TORCH_CHECK(a_tensors.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(
b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Int); // int4 8x packed into int32
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(
b_group_scales.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn); // the underlying torch
// type is e4m3
STD_TORCH_CHECK(
out_tensors.scalar_type() ==
torch::headeronly::ScalarType::BFloat16); // only support bf16 for now
// expect int64_t to avoid overflow during offset calculations
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
torch::headeronly::ScalarType::Long);
int num_experts = static_cast<int>(expert_offsets.size(0));
// logical k, n
@@ -90,14 +100,15 @@ void run_get_group_gemm_starts(
int64_t k = a_tensors.size(1);
int64_t scale_k = cutlass::ceil_div(k, b_group_size);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
if (false) {
}
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}

View File

@@ -14,13 +14,12 @@
#include "cutlass/util/mixed_dtype_utils.hpp"
// vllm includes
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "cutlass_extensions/common.hpp"
#include "core/registration.h"
#include "get_group_starts.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "w4a8_utils.cuh"
@@ -168,31 +167,40 @@ struct W4A8GroupedGemmKernel {
static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
"LayoutB_Reordered size must be divisible by 4 bytes");
static void grouped_mm(
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides) {
static void grouped_mm(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales,
const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes_torch,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides) {
auto device = a_tensors.device();
auto device_id = device.index();
const at::cuda::OptionalCUDAGuard device_guard(device);
auto stream = at::cuda::getCurrentCUDAStream(device_id);
const torch::stable::accelerator::DeviceGuard device_guard(device_id);
auto stream = get_current_cuda_stream(device_id);
int num_experts = static_cast<int>(expert_offsets.size(0));
int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor;
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(device);
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor out_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_group_scales_ptrs = torch::stable::empty(
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
// get the correct offsets to pass to gemm
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
@@ -247,9 +255,9 @@ struct W4A8GroupedGemmKernel {
// Allocate workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::Tensor workspace =
torch::empty(workspace_size,
torch::TensorOptions().dtype(torch::kU8).device(device));
torch::stable::Tensor workspace = torch::stable::empty(
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
device);
// Run GEMM
GemmShuffled gemm;
@@ -294,14 +302,20 @@ using Kernel_256x128_2x1x1_Coop =
using Kernel_128x256_2x1x1_Coop =
W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;
void mm_dispatch(
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides, const std::string& schedule) {
void mm_dispatch(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales,
const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
const std::string& schedule) {
if (schedule == "Kernel_128x16_1x1x1_Coop") {
Kernel_128x16_1x1x1_Coop::grouped_mm(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
@@ -358,18 +372,23 @@ void mm_dispatch(
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, group_scale_strides);
} else {
TORCH_CHECK(false,
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
STD_TORCH_CHECK(false,
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
}
}
void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
const int64_t b_group_size, const torch::Tensor& expert_offsets,
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
const torch::Tensor& group_scale_strides,
void mm(torch::stable::Tensor& out_tensors,
const torch::stable::Tensor& a_tensors,
const torch::stable::Tensor& b_tensors,
const torch::stable::Tensor& a_scales,
const torch::stable::Tensor& b_scales,
const torch::stable::Tensor& b_group_scales, const int64_t b_group_size,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
const torch::stable::Tensor& group_scale_strides,
std::optional<std::string> maybe_schedule) {
// user has specified a schedule
if (maybe_schedule) {
@@ -406,26 +425,27 @@ void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
a_strides, b_strides, c_strides, group_scale_strides, schedule);
}
std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
torch::Tensor const& b_tensors) {
TORCH_CHECK(b_tensors.dtype() == torch::kInt32);
TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
TORCH_CHECK(b_tensors.is_contiguous());
TORCH_CHECK(b_tensors.is_cuda());
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) {
STD_TORCH_CHECK(b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
STD_TORCH_CHECK(b_tensors.is_contiguous());
STD_TORCH_CHECK(b_tensors.is_cuda());
int n = static_cast<int>(b_tensors.size(1));
int k = static_cast<int>(b_tensors.size(2)) * PackFactor; // logical k
// CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
// These misalignments cause silent OOB unless run under Compute Sanitizer.
TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
STD_TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
STD_TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
// we will store the layout to an int32 tensor;
// this is the number of elements we need per layout
constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);
torch::Tensor b_tensors_packed = torch::empty_like(b_tensors);
torch::stable::Tensor b_tensors_packed = torch::stable::empty_like(b_tensors);
int num_experts = static_cast<int>(b_tensors.size(0));
auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
@@ -435,7 +455,7 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
size_t num_int4_elems = 1ull * num_experts * n * k;
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
num_int4_elems);
TORCH_CHECK(ok, "unified_encode_int4b failed");
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
// construct the layout once; assumes each expert has the same layout
using LayoutType = LayoutB_Reordered;
@@ -456,28 +476,28 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
}
// save the packed layout to torch tensor so we can re-use it
auto cpu_opts =
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
torch::Tensor layout_cpu =
torch::empty({num_experts, layout_width}, cpu_opts);
torch::stable::Tensor layout_cpu = torch::stable::empty(
{num_experts, layout_width}, torch::headeronly::ScalarType::Int,
std::nullopt, torch::stable::Device(torch::stable::DeviceType::CPU));
int32_t* layout_data = layout_cpu.data_ptr<int32_t>();
int32_t* layout_data = layout_cpu.mutable_data_ptr<int32_t>();
for (int i = 0; i < num_experts; ++i) {
std::memcpy(layout_data + i * layout_width, // dst (int32*)
&layout_B_reordered, // src (LayoutType*)
sizeof(LayoutType)); // number of bytes
}
torch::Tensor packed_layout =
layout_cpu.to(b_tensors.device(), /*non_blocking=*/false);
torch::stable::Tensor packed_layout =
torch::stable::to(layout_cpu, b_tensors.device(),
/*non_blocking=*/false);
return {b_tensors_packed, packed_layout};
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_w4a8_moe_mm", &mm);
m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_moe_mm", TORCH_BOX(&mm));
m.impl("cutlass_encode_and_reorder_int4b_grouped",
TORCH_BOX(&encode_and_reorder_int4b));
}
} // namespace vllm::cutlass_w4a8_moe
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -3,14 +3,12 @@
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
//
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "w4a8_utils.cuh"
#include "core/registration.h"
#include "cutlass/cutlass.h"
#include <limits>
@@ -161,31 +159,31 @@ struct W4A8GemmKernel {
using StrideD = typename GemmKernelShuffled::StrideD;
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
static torch::Tensor mm(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size,
torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type) {
static torch::stable::Tensor mm(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type) {
// TODO: param validation
int m = A.size(0);
int k = A.size(1);
int n = B.size(1);
// safely cast group_size to int
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
STD_TORCH_CHECK(
group_size > 0 && group_size <= std::numeric_limits<int>::max(),
"group_size out of supported range for int: ", group_size);
int const group_size_int = static_cast<int>(group_size);
// Allocate output
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
auto device = A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
torch::Tensor D =
torch::empty({m, n}, torch::TensorOptions()
.dtype(equivalent_scalar_type_v<ElementD>)
.device(device));
auto stream = get_current_cuda_stream(device.index());
torch::stable::Tensor D = torch::stable::empty(
{m, n}, equivalent_scalar_type_v<ElementD>, std::nullopt, device);
// prepare arg pointers
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
@@ -237,9 +235,9 @@ struct W4A8GemmKernel {
// Workspace
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
torch::Tensor workspace =
torch::empty(workspace_size,
torch::TensorOptions().dtype(torch::kU8).device(device));
torch::stable::Tensor workspace = torch::stable::empty(
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
device);
// Run GEMM
GemmShuffled gemm;
@@ -269,14 +267,14 @@ using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
torch::Tensor mm_dispatch(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size,
torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type,
const std::string& schedule) {
torch::stable::Tensor mm_dispatch(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
const std::string& schedule) {
if (schedule == "256x128_1x1x1") {
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
channel_scales, token_scales,
@@ -318,17 +316,18 @@ torch::Tensor mm_dispatch(torch::Tensor const& A,
channel_scales, token_scales,
maybe_out_type);
}
TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
STD_TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
return {};
}
torch::Tensor mm(torch::Tensor const& A,
torch::Tensor const& B, // already packed
torch::Tensor const& group_scales, // already packed
int64_t group_size, torch::Tensor const& channel_scales,
torch::Tensor const& token_scales,
std::optional<at::ScalarType> const& maybe_out_type,
std::optional<std::string> maybe_schedule) {
torch::stable::Tensor mm(
torch::stable::Tensor const& A,
torch::stable::Tensor const& B, // already packed
torch::stable::Tensor const& group_scales, // already packed
int64_t group_size, torch::stable::Tensor const& channel_scales,
torch::stable::Tensor const& token_scales,
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
std::optional<std::string> maybe_schedule) {
// requested a specific schedule
if (maybe_schedule) {
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
@@ -378,14 +377,15 @@ torch::Tensor mm(torch::Tensor const& A,
// ----------------------------------------------------------------------------
// Pre-processing utils
// ----------------------------------------------------------------------------
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(scales.is_contiguous());
TORCH_CHECK(scales.is_cuda());
torch::stable::Tensor pack_scale_fp8(torch::stable::Tensor const& scales) {
STD_TORCH_CHECK(scales.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(scales.is_contiguous());
STD_TORCH_CHECK(scales.is_cuda());
auto packed_scales = torch::empty(
{scales.numel() * ScalePackSize},
torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
auto packed_scales =
torch::stable::empty({scales.numel() * ScalePackSize},
scales.scalar_type(), std::nullopt, scales.device());
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
auto packed_scales_ptr =
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
@@ -396,15 +396,16 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
return packed_scales;
}
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
TORCH_CHECK(B.dtype() == torch::kInt32);
TORCH_CHECK(B.dim() == 2);
torch::stable::Tensor encode_and_reorder_int4b(torch::stable::Tensor const& B) {
STD_TORCH_CHECK(B.scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(B.dim() == 2);
torch::Tensor B_packed = torch::empty_like(B);
torch::stable::Tensor B_packed = torch::stable::empty_like(B);
int k = B.size(0) * PackFactor; // logical k
int n = B.size(1);
TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");
STD_TORCH_CHECK((n * k) % 32 == 0,
"need multiples of 32 int4s for 16B chunks");
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
@@ -415,16 +416,17 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
n * k);
TORCH_CHECK(ok, "unified_encode_int4b failed");
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
return B_packed;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_w4a8_mm", &mm);
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_w4a8_mm", TORCH_BOX(&mm));
m.impl("cutlass_pack_scale_fp8", TORCH_BOX(&pack_scale_fp8));
m.impl("cutlass_encode_and_reorder_int4b",
TORCH_BOX(&encode_and_reorder_int4b));
}
} // namespace vllm::cutlass_w4a8

View File

@@ -14,16 +14,15 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
@@ -118,17 +117,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
} // namespace vllm
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
torch::Tensor& output_sf,
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& input_sf) {
void silu_and_mul_nvfp4_quant_sm1xxa(
torch::stable::Tensor& output, // [..., d]
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input, // [..., 2 * d]
torch::stable::Tensor& input_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1) / 2;
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
STD_TORCH_CHECK(
input.scalar_type() == torch::headeronly::ScalarType::Half ||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
@@ -136,8 +137,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
@@ -149,7 +151,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());

View File

@@ -14,14 +14,12 @@
* limitations under the License.
*/
#include "core/registration.h"
#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"
@@ -122,7 +120,7 @@ __global__ void __get_group_gemm_starts(
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
LayoutSFB, ScaleConfig) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
LayoutSFA, LayoutSFB, ScaleConfig> \
<<<1, num_experts, 0, stream>>>( \
@@ -150,50 +148,64 @@ __global__ void __get_group_gemm_starts(
}
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_group_gemm_starts(
const torch::Tensor& a_starts, const torch::Tensor& b_starts,
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
const torch::Tensor& a_strides, const torch::Tensor& b_strides,
const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& alphas,
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
torch::Tensor const& problem_sizes, int M, int N, int K) {
void run_get_group_gemm_starts(const torch::stable::Tensor& a_starts,
const torch::stable::Tensor& b_starts,
const torch::stable::Tensor& out_starts,
const torch::stable::Tensor& a_scales_starts,
const torch::stable::Tensor& b_scales_starts,
const torch::stable::Tensor& alpha_starts,
const torch::stable::Tensor& layout_sfa,
const torch::stable::Tensor& layout_sfb,
const torch::stable::Tensor& a_strides,
const torch::stable::Tensor& b_strides,
const torch::stable::Tensor& c_strides,
int64_t a_stride_val, int64_t b_stride_val,
int64_t c_stride_val,
/*these are used for their base addresses*/
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& alphas,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& sf_offsets,
torch::stable::Tensor const& problem_sizes,
int M, int N, int K) {
int num_experts = (int)expert_offsets.size(0);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
STD_TORCH_CHECK(out_tensors.size(1) == N,
"Output tensor shape doesn't match expected shape");
STD_TORCH_CHECK(K / 2 == b_tensors.size(2),
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
" dimension must match");
if (false) {
}
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
// ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
cutlass::float_e2m1_t, cutlass::float_ue4m3_t,
torch::headeronly::ScalarType::BFloat16, cutlass::bfloat16_t, LayoutSFA,
LayoutSFB, ScaleConfig)
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
cutlass::float_ue4m3_t, torch::kFloat16,
half, LayoutSFA, LayoutSFB, ScaleConfig)
cutlass::float_ue4m3_t,
torch::headeronly::ScalarType::Half, half,
LayoutSFA, LayoutSFB, ScaleConfig)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm_sm100(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
@@ -272,20 +284,40 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor out_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor a_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -308,7 +340,7 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
hw_info.device_id = a.get_device_index();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
@@ -350,32 +382,35 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, a.device());
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
STD_TORCH_CHECK(
can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size,
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
void run_fp4_blockwise_scaled_group_mm_sm120(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
using ProblemShape =
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
using ElementType = cutlass::float_e2m1_t;
@@ -446,20 +481,40 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
int num_experts = static_cast<int>(expert_offsets.size(0));
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor out_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor a_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_scales_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor alpha_ptrs =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor layout_sfa = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor layout_sfb = torch::stable::empty(
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
a.device());
torch::stable::Tensor a_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor b_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
torch::stable::Tensor c_strides1 =
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
std::nullopt, a.device());
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
@@ -480,7 +535,7 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
hw_info.device_id = a.get_device_index();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
cached_sm_counts[hw_info.device_id] =
@@ -523,33 +578,36 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
scheduler};
size_t workspace_size = Gemm::get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, a.device());
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
STD_TORCH_CHECK(
can_implement_status == cutlass::Status::kSuccess,
"Failed to implement GEMM: status=", (int)can_implement_status);
// Run the GEMM
auto status = gemm_op.initialize(args, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size, " num_experts=", num_experts,
" M=", M, " N=", N, " K=", K);
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
"Failed to initialize GEMM: status=", (int)status,
" workspace_size=", workspace_size,
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
status = gemm_op.run(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
}
template <typename OutType>
void run_fp4_blockwise_scaled_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
int N, int K) {
torch::stable::Tensor& output, const torch::stable::Tensor& a,
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
if (version_num >= 120 && version_num < 130) {
@@ -567,7 +625,7 @@ void run_fp4_blockwise_scaled_group_mm(
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 100 or 120");
@@ -575,26 +633,31 @@ void run_fp4_blockwise_scaled_group_mm(
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
#endif
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
const torch::stable::Tensor& a,
const torch::stable::Tensor& b,
const torch::stable::Tensor& a_blockscale,
const torch::stable::Tensor& b_blockscales,
const torch::stable::Tensor& alphas,
const torch::stable::Tensor& problem_sizes,
const torch::stable::Tensor& expert_offsets,
const torch::stable::Tensor& sf_offsets) {
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
// Input validation
@@ -602,30 +665,34 @@ void cutlass_fp4_group_mm(
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
CHECK_INPUT(alphas, torch::headeronly::ScalarType::Float, "alphas");
TORCH_CHECK(a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
"problem_sizes must be int32.");
STD_TORCH_CHECK(
a_blockscale.dim() == 2,
"expected a_blockscale to be of shape [num_experts, rounded_m,"
" k // group_size], observed rank: ",
a_blockscale.dim())
STD_TORCH_CHECK(b_blockscales.dim() == 3,
"expected b_blockscale to be of shape: "
" [num_experts, n, k // group_size], observed rank: ",
b_blockscales.dim())
STD_TORCH_CHECK(problem_sizes.dim() == 2,
"problem_sizes must be a 2D tensor");
STD_TORCH_CHECK(problem_sizes.size(1) == 3,
"problem_sizes must have the shape (num_experts, 3)");
STD_TORCH_CHECK(
problem_sizes.size(0) == expert_offsets.size(0),
"Number of experts in problem_sizes must match expert_offsets");
STD_TORCH_CHECK(
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32.");
int M = static_cast<int>(a.size(0));
int N = static_cast<int>(b.size(1));
int E = static_cast<int>(b.size(0));
int K = static_cast<int>(2 * b.size(2));
if (output.scalar_type() == torch::kBFloat16) {
if (output.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
expert_offsets, sf_offsets, M, N, K);
@@ -633,7 +700,7 @@ void cutlass_fp4_group_mm(
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
int32_t version_num = get_sm_version_num();
if (version_num >= 120 && version_num < 130) {
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
output.scalar_type());
}
@@ -643,7 +710,7 @@ void cutlass_fp4_group_mm(
expert_offsets, sf_offsets, M, N, K);
}
#else
TORCH_CHECK_NOT_IMPLEMENTED(
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
@@ -651,6 +718,6 @@ void cutlass_fp4_group_mm(
#endif
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
m.impl("cutlass_fp4_group_mm", TORCH_BOX(&cutlass_fp4_group_mm));
}

View File

@@ -14,16 +14,15 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
@@ -327,25 +326,28 @@ void quant_impl(void* output, void* output_scale, void* input,
} // namespace vllm
/*Quantization entry for fp4 experts quantization*/
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_TH_CUDA(x, m) \
STD_TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
STD_TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m);
constexpr auto HALF = at::ScalarType::Half;
constexpr auto BF16 = at::ScalarType::BFloat16;
constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
constexpr auto HALF = torch::headeronly::ScalarType::Half;
constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
constexpr auto FLOAT = torch::headeronly::ScalarType::Float;
constexpr auto INT = torch::headeronly::ScalarType::Int;
constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;
// Common validation for fp4 experts quantization entry points.
static void validate_fp4_experts_quant_inputs(
torch::Tensor const& output, torch::Tensor const& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
torch::stable::Tensor const& output,
torch::stable::Tensor const& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
int64_t k) {
CHECK_INPUT(output, "output");
CHECK_INPUT(output_scale, "output_scale");
@@ -354,41 +356,42 @@ static void validate_fp4_experts_quant_inputs(
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
STD_TORCH_CHECK(output.dim() == 2);
STD_TORCH_CHECK(output_scale.dim() == 2);
STD_TORCH_CHECK(input.dim() == 2);
STD_TORCH_CHECK(input_global_scale.dim() == 1);
STD_TORCH_CHECK(input_offset_by_experts.dim() == 1);
STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
STD_TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
TORCH_CHECK(output.scalar_type() == UINT8);
TORCH_CHECK(output_scale.scalar_type() == INT);
STD_TORCH_CHECK(output.scalar_type() == UINT8);
STD_TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16;
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
STD_TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2);
STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
STD_TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
STD_TORCH_CHECK(output.size(0) == m_topk);
STD_TORCH_CHECK(output.size(1) == k / 2);
int scales_k = k / BLOCK_SIZE;
// 4 means the swizzle requirement by nvidia nvfp4.
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
}
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
auto k = input.size(1);
@@ -397,11 +400,11 @@ void scaled_fp4_experts_quant_sm1xxa(
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
@@ -413,14 +416,15 @@ void scaled_fp4_experts_quant_sm1xxa(
}
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
// Input has gate || up layout, so k = input.size(1) / 2
auto k_times_2 = input.size(1);
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
STD_TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
auto k = k_times_2 / 2;
validate_fp4_experts_quant_inputs(output, output_scale, input,
@@ -428,11 +432,11 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
VLLM_DISPATCH_HALF_TYPES(
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(

View File

@@ -0,0 +1,172 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
torch::stable::Tensor const& input,
torch::stable::Tensor const& output_sf,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant_sm1xxa(torch::stable::Tensor& output,
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input,
torch::stable::Tensor& input_sf);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts);
#endif
static bool nvfp4_quant_sm_supported() {
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) return true;
#endif
return false;
}
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout,
torch::stable::Tensor& output,
torch::stable::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 quantization kernel");
}
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
torch::stable::Tensor const& input, torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int64_t n = input.size(-1);
int64_t m = input.numel() / n;
auto device = input.device();
// Two fp4 values packed into a uint8
auto output = torch::stable::empty(
{m, n / 2}, torch::headeronly::ScalarType::Byte, std::nullopt, device);
torch::stable::Tensor output_sf;
if (is_sf_swizzled_layout) {
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
output_sf = torch::stable::empty(
{sf_m, sf_n}, torch::headeronly::ScalarType::Int, std::nullopt, device);
} else {
output_sf = torch::stable::empty({m, n / CVT_FP4_SF_VEC_SIZE},
torch::headeronly::ScalarType::Byte,
std::nullopt, device);
}
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
output_sf);
return {output, output_sf};
}
void scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 experts quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled nvfp4 experts quantization kernel");
}
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& output,
torch::stable::Tensor& output_sf,
torch::stable::Tensor& input,
torch::stable::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
void silu_and_mul_scaled_fp4_experts_quant(
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
torch::stable::Tensor const& input,
torch::stable::Tensor const& input_global_scale,
torch::stable::Tensor const& input_offset_by_experts,
torch::stable::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 experts quantization kernel "
"for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
}

View File

@@ -14,16 +14,16 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "libtorch_stable/torch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include "cuda_vec_utils.cuh"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
@@ -173,18 +173,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
} // namespace vllm
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf,
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
torch::stable::Tensor const& input,
torch::stable::Tensor const& output_sf,
torch::stable::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int32_t m = input.size(0);
int32_t n = input.size(1);
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
input.scalar_type() == at::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
STD_TORCH_CHECK(
input.scalar_type() == torch::headeronly::ScalarType::Half ||
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
"Unsupported input data type for quantize_to_fp4.");
int multiProcessorCount =
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
@@ -192,8 +193,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
const torch::stable::accelerator::DeviceGuard device_guard(
input.get_device_index());
auto stream = get_current_cuda_stream(input.get_device_index());
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
@@ -213,15 +215,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, num_padded_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, num_padded_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
} else {
int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
@@ -229,15 +231,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
VLLM_STABLE_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(
m, n, sf_n_unpadded, num_packed_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
}
}

View File

@@ -14,32 +14,39 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha);
#endif
void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
const torch::Tensor& B, const torch::Tensor& A_sf,
const torch::Tensor& B_sf,
const torch::Tensor& alpha) {
// Make sure were on As device.
const c10::cuda::OptionalCUDAGuard device_guard(device_of(A));
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
const torch::stable::Tensor& A,
const torch::stable::Tensor& B,
const torch::stable::Tensor& A_sf,
const torch::stable::Tensor& B_sf,
const torch::stable::Tensor& alpha) {
// Make sure we're on A's device.
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
@@ -56,8 +63,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled nvfp4 mm kernel for SM ", sm,
". Recompile with CUDA >= 12.8 and CC >= 100.");
}
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {

View File

@@ -14,10 +14,9 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
@@ -127,8 +126,9 @@ struct Fp4GemmSm100 {
template <typename Config>
typename Config::Gemm::Arguments args_from_options(
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha,
int64_t M, int64_t N, int64_t K) {
using ElementA = typename Config::Gemm::ElementA;
using ElementB = typename Config::Gemm::ElementB;
@@ -174,19 +174,20 @@ typename Config::Gemm::Arguments args_from_options(
}
template <typename Config>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf,
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
cudaStream_t stream) {
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
typename Config::Gemm gemm;
auto arguments =
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, A.device());
CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -197,12 +198,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
// Dispatch function to select appropriate config based on M
template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m,
int64_t n, int64_t k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 16) {
@@ -222,61 +224,65 @@ void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
#else
template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int64_t m, int64_t n,
int64_t k, cudaStream_t stream) {
TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int64_t m,
int64_t n, int64_t k, cudaStream_t stream) {
STD_TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
}
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha) {
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha) {
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
STD_TORCH_CHECK(A.size(1) == B.size(1),
"a and b shapes cannot be multiplied (", A.size(0), "x",
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1] * 2;
auto const m = A.size(0);
auto const n = B.size(0);
auto const k = A.size(1) * 2;
constexpr int alignment = 32;
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
"), k: ", k, ".");
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
"), k: ", k, ".");
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
@@ -285,33 +291,34 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
// integer.
int rounded_k = round_up(k / 16, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
"x", B_sf.sizes()[1], ")");
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
A_sf.sizes()[1], ")");
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
B_sf.sizes()[1], ")");
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
B_sf.size(1), ")");
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
A_sf.size(1), ")");
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
B_sf.size(1), ")");
auto out_dtype = D.dtype();
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
auto out_dtype = D.scalar_type();
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
if (out_dtype == at::ScalarType::Half) {
if (out_dtype == torch::headeronly::ScalarType::Half) {
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
k, stream);
} else if (out_dtype == at::ScalarType::BFloat16) {
} else if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
m, n, k, stream);
} else {
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
")");
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (",
out_dtype, ")");
}
}

View File

@@ -14,10 +14,9 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
@@ -34,19 +33,20 @@
using namespace cute;
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TYPE(x, st, m) \
STD_TORCH_CHECK(x.scalar_type() == st, \
": Inconsistency of torch::stable::Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
struct sm120_fp4_config_M256 {
using ClusterShape = Shape<_1, _1, _1>;
@@ -109,12 +109,13 @@ struct Fp4GemmSm120 {
};
template <typename Gemm>
typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
torch::Tensor const& alpha, int M,
int N, int K) {
typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha,
int M, int N, int K) {
using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementD = typename Gemm::ElementD;
@@ -158,18 +159,19 @@ typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
}
template <typename Gemm>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf,
torch::Tensor const& alpha, int M, int N, int K,
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int M, int N, int K,
cudaStream_t stream) {
Gemm gemm;
auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, A.device());
CUTLASS_CHECK(gemm.can_implement(arguments));
@@ -178,12 +180,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int m, int n,
int k, cudaStream_t stream) {
void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
@@ -194,12 +197,13 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
}
}
void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha, int m, int n,
int k, cudaStream_t stream) {
void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
if (mp2 <= 256) {
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
@@ -210,11 +214,12 @@ void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
}
}
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha) {
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
torch::stable::Tensor const& A,
torch::stable::Tensor const& B,
torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha) {
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
@@ -222,24 +227,25 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
STD_TORCH_CHECK(A.size(1) == B.size(1),
"a and b shapes cannot be multiplied (", A.size(0), "x",
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1] * 2;
auto const m = A.size(0);
auto const n = B.size(0);
auto const k = A.size(1) * 2;
constexpr int alignment = 32;
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
"), k: ", k, ".");
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
"), k: ", k, ".");
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
@@ -248,38 +254,39 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
// integer.
int rounded_k = round_up(k / 16, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
"x", B_sf.sizes()[1], ")");
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
A_sf.sizes()[1], ")");
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
B_sf.sizes()[1], ")");
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
B_sf.size(1), ")");
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
"scale_a must be padded and swizzled to a shape (", rounded_m,
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
A_sf.size(1), ")");
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
"scale_b must be padded and swizzled to a shape (", rounded_n,
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
B_sf.size(1), ")");
auto out_dtype = D.dtype();
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
auto out_dtype = D.scalar_type();
const torch::stable::accelerator::DeviceGuard device_guard(
A.get_device_index());
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
if (out_dtype == at::ScalarType::BFloat16) {
if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream);
} else if (out_dtype == at::ScalarType::Half) {
} else if (out_dtype == torch::headeronly::ScalarType::Half) {
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
stream);
} else {
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
out_dtype, ")");
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
out_dtype, ")");
}
#else
TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
STD_TORCH_CHECK(false,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support.");
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
}

View File

@@ -20,7 +20,7 @@
#include <cuda_fp8.h>
#include <utility>
#include "../../cuda_vec_utils.cuh"
#include "cuda_vec_utils.cuh"
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12090

View File

@@ -103,6 +103,102 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
"bool");
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
// Out variant
// TODO: Add out_variant tag once PyTorch supports it (added in 2.11)
// This registration is now migrated to stable ABI
// at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not
// yet in torch/headeronly), the tag should be applied from Python
// via torch.library.Library.define(..., tags=(torch.Tag.out_variant,))
// with the .impl remaining in C++.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
// Fused SiLU+Mul+NVFP4 quantization.
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()");
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
#endif
}
@@ -128,6 +224,18 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
ops.impl("get_cutlass_batched_moe_mm_data",
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
// FP4/NVFP4 ops
ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm));
ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func));
ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out));
ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant));
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
// W4A8 ops: impl registrations are in the source files
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
#endif
}
@@ -143,6 +251,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
TORCH_BOX(&cutlass_group_gemm_supported));
ops.impl("cutlass_scaled_mm_supports_block_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
ops.impl("cutlass_scaled_mm_supports_fp4",
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
#endif
}

View File

@@ -2,6 +2,7 @@
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/ops.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h>

View File

@@ -152,12 +152,6 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
#ifndef USE_ROCM
void silu_and_mul_nvfp4_quant(torch::Tensor& out,
torch::Tensor& output_block_scale,
torch::Tensor& input,
torch::Tensor& input_global_scale);
#endif
void persistent_masked_m_silu_mul_quant(
const at::Tensor& input, // (E, T, 2*H)
const at::Tensor& counts, // (E)
@@ -225,44 +219,6 @@ torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B, torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_scale,
bool is_sf_swizzled_layout);
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_scale,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_scale);
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale,
std::optional<torch::Tensor> const& azp);

View File

@@ -1,163 +0,0 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input,
torch::Tensor const& output_sf,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
torch::Tensor& output_sf,
torch::Tensor& input,
torch::Tensor& input_sf);
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
static bool nvfp4_quant_sm_supported() {
const int32_t sm = get_sm_version_num();
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
if (sm >= 100 && sm < 120) return true;
#endif
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
if (sm >= 120 && sm < 130) return true;
#endif
return false;
}
void scaled_fp4_quant_out(torch::Tensor const& input,
torch::Tensor const& input_sf,
bool is_sf_swizzled_layout, torch::Tensor& output,
torch::Tensor& output_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
}
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int64_t n = input.size(-1);
int64_t m = input.numel() / n;
auto device = input.device();
// Two fp4 values packed into a uint8
auto output = torch::empty(
{m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));
torch::Tensor output_sf;
if (is_sf_swizzled_layout) {
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
output_sf = torch::empty(
{sf_m, sf_n},
torch::TensorOptions().device(device).dtype(torch::kInt32));
} else {
output_sf = torch::empty(
{m, n / CVT_FP4_SF_VEC_SIZE},
torch::TensorOptions().device(device).dtype(torch::kUInt8));
}
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
output_sf);
return {output, output_sf};
}
void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled nvfp4 experts quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled nvfp4 experts quantization kernel");
}
void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
torch::Tensor& input, torch::Tensor& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
TORCH_CHECK(nvfp4_quant_sm_supported(),
"No compiled silu_and_mul nvfp4 experts quantization kernel "
"for SM ",
get_sm_version_num(),
". Recompile with the appropriate CUDA arch.");
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
}

View File

@@ -109,13 +109,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
#ifndef USE_ROCM
ops.def(
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
"Tensor input, Tensor input_global_scale) -> ()");
ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant);
#endif
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
@@ -332,47 +325,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor? qzeros_or_none, bool inplace) -> Tensor");
// conditionally compiled so impl registrations are in source file
// CUTLASS w4a8 GEMM
ops.def(
"cutlass_w4a8_mm("
" Tensor A,"
" Tensor B,"
" Tensor group_scales,"
" int group_size,"
" Tensor channel_scales,"
" Tensor token_scales,"
" ScalarType? out_type,"
" str? maybe_schedule"
") -> Tensor");
// pack scales
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
// encode and reorder weight matrix
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
// conditionally compiled so impl registration is in source file
// CUTLASS w4a8 grouped GEMM
ops.def(
"cutlass_w4a8_moe_mm("
" Tensor! out_tensors,"
" Tensor a_tensors,"
" Tensor b_tensors,"
" Tensor a_scales,"
" Tensor b_scales,"
" Tensor b_group_scales,"
" int b_group_size,"
" Tensor expert_offsets,"
" Tensor problem_sizes,"
" Tensor a_strides,"
" Tensor b_strides,"
" Tensor c_strides,"
" Tensor group_scale_strides,"
" str? maybe_schedule"
") -> ()");
ops.def(
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
"Tensor)");
// conditionally compiled so impl registration is in source file
#endif
// Dequantization for GGML.
@@ -409,20 +361,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
#ifndef USE_ROCM
// CUTLASS nvfp4 block scaled GEMM
ops.def(
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor block_scale_a, Tensor block_scale_b,"
" Tensor alpha) -> ()");
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
// cutlass nvfp4 block scaled group GEMM
ops.def(
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
// conditionally compiled so impl registration is in source file
// Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
ops.def(
"mxfp8_experts_quant("
@@ -455,44 +393,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"-> int");
// conditionally compiled so impl in source file
// Compute NVFP4 block quantized tensor.
ops.def(
"scaled_fp4_quant(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func);
// Out variant
// TODO: Add {at::Tag::out_variant} tag and update all call sites
// to use the functional variant once vLLM upgrades PyTorch.
// See pytorch/pytorch#176117.
ops.def(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()");
ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out);
// Compute NVFP4 experts quantization.
ops.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA,
&silu_and_mul_scaled_fp4_experts_quant);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4);
#endif
// Quantized GEMM for GPTQ.