[4/n] Migrate FP4/W4A8 CUTLASS kernels to torch stable ABI (#37503)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
178
CMakeLists.txt
178
CMakeLists.txt
@@ -340,8 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
|
|
||||||
list(APPEND VLLM_EXT_SRC
|
list(APPEND VLLM_EXT_SRC
|
||||||
"csrc/quantization/awq/gemm_kernels.cu"
|
"csrc/quantization/awq/gemm_kernels.cu"
|
||||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
|
||||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
|
||||||
"csrc/cutlass_extensions/common.cpp")
|
"csrc/cutlass_extensions/common.cpp")
|
||||||
|
|
||||||
set_gencode_flags_for_srcs(
|
set_gencode_flags_for_srcs(
|
||||||
@@ -489,59 +487,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
" in CUDA target architectures")
|
" in CUDA target architectures")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
|
|
||||||
# CUDA 12.8 or later
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
|
||||||
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
|
|
||||||
else()
|
|
||||||
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
|
||||||
endif()
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
|
||||||
set(SRCS
|
|
||||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
|
||||||
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
|
||||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
|
||||||
"csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
|
|
||||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
|
||||||
set_gencode_flags_for_srcs(
|
|
||||||
SRCS "${SRCS}"
|
|
||||||
CUDA_ARCHS "${FP4_ARCHS}")
|
|
||||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
|
||||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
|
|
||||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
|
|
||||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
|
||||||
else()
|
|
||||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
|
||||||
# clear FP4_ARCHS
|
|
||||||
set(FP4_ARCHS)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# FP4 Archs and flags
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
|
||||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
|
||||||
else()
|
|
||||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
|
||||||
endif()
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
|
||||||
set(SRCS
|
|
||||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
|
||||||
"csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
|
||||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
|
||||||
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
|
|
||||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
|
||||||
set_gencode_flags_for_srcs(
|
|
||||||
SRCS "${SRCS}"
|
|
||||||
CUDA_ARCHS "${FP4_ARCHS}")
|
|
||||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
|
||||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
|
|
||||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
|
||||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
|
||||||
else()
|
|
||||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
|
||||||
# clear FP4_ARCHS
|
|
||||||
set(FP4_ARCHS)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# CUTLASS MLA Archs and flags
|
# CUTLASS MLA Archs and flags
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
|
||||||
@@ -681,34 +626,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Only build W4A8 kernels if we are building for something compatible with sm90a
|
|
||||||
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
|
|
||||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
|
|
||||||
set(SRCS
|
|
||||||
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
|
|
||||||
"csrc/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
|
|
||||||
"csrc/quantization/cutlass_w4a8/w4a8_utils.cu"
|
|
||||||
)
|
|
||||||
|
|
||||||
set_gencode_flags_for_srcs(
|
|
||||||
SRCS "${SRCS}"
|
|
||||||
CUDA_ARCHS "${W4A8_ARCHS}")
|
|
||||||
|
|
||||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
|
||||||
|
|
||||||
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
|
|
||||||
else()
|
|
||||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
|
|
||||||
AND W4A8_ARCHS)
|
|
||||||
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
|
|
||||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
|
||||||
"later if you intend on running w4a16 quantized models on "
|
|
||||||
"Hopper.")
|
|
||||||
else()
|
|
||||||
message(STATUS "Not building W4A8 kernels as no compatible archs "
|
|
||||||
"found in CUDA target architectures")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Hadacore kernels
|
# Hadacore kernels
|
||||||
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0+PTX;9.0+PTX" "${CUDA_ARCHS}")
|
||||||
@@ -760,7 +677,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
set(VLLM_STABLE_EXT_SRC
|
set(VLLM_STABLE_EXT_SRC
|
||||||
"csrc/libtorch_stable/torch_bindings.cpp"
|
"csrc/libtorch_stable/torch_bindings.cpp"
|
||||||
"csrc/cutlass_extensions/common.cpp"
|
"csrc/cutlass_extensions/common.cpp"
|
||||||
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu")
|
"csrc/cuda_utils_kernels.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_entry.cu")
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
list(APPEND VLLM_STABLE_EXT_SRC
|
list(APPEND VLLM_STABLE_EXT_SRC
|
||||||
@@ -978,6 +898,96 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
|||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
#
|
||||||
|
# FP4/NVFP4 kernels (moved from _C to _C_stable_libtorch)
|
||||||
|
#
|
||||||
|
|
||||||
|
# The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require
|
||||||
|
# CUDA 12.8 or later
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||||
|
cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
|
||||||
|
else()
|
||||||
|
cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}")
|
||||||
|
endif()
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||||
|
set(SRCS
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${FP4_ARCHS}")
|
||||||
|
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM120=1")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
|
||||||
|
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||||
|
# clear FP4_ARCHS
|
||||||
|
set(FP4_ARCHS)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
# FP4 Archs and flags
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
|
||||||
|
cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
|
||||||
|
else()
|
||||||
|
cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}")
|
||||||
|
endif()
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||||
|
set(SRCS
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/nvfp4_experts_quant.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${FP4_ARCHS}")
|
||||||
|
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4_SM100=1")
|
||||||
|
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||||
|
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||||
|
# clear FP4_ARCHS
|
||||||
|
set(FP4_ARCHS)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
#
|
||||||
|
# W4A8 kernels (moved from _C to _C_stable_libtorch)
|
||||||
|
#
|
||||||
|
|
||||||
|
# Only build W4A8 kernels if we are building for something compatible with sm90a
|
||||||
|
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||||
|
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
|
||||||
|
set(SRCS
|
||||||
|
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_mm_entry.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_grouped_mm_entry.cu"
|
||||||
|
"csrc/libtorch_stable/quantization/cutlass_w4a8/w4a8_utils.cu"
|
||||||
|
)
|
||||||
|
|
||||||
|
set_gencode_flags_for_srcs(
|
||||||
|
SRCS "${SRCS}"
|
||||||
|
CUDA_ARCHS "${W4A8_ARCHS}")
|
||||||
|
|
||||||
|
list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}")
|
||||||
|
|
||||||
|
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
|
||||||
|
else()
|
||||||
|
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
|
||||||
|
AND W4A8_ARCHS)
|
||||||
|
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
|
||||||
|
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||||
|
"later if you intend on running w4a16 quantized models on "
|
||||||
|
"Hopper.")
|
||||||
|
else()
|
||||||
|
message(STATUS "Not building W4A8 kernels as no compatible archs "
|
||||||
|
"found in CUDA target architectures")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
message(STATUS "Enabling C_stable extension.")
|
message(STATUS "Enabling C_stable extension.")
|
||||||
define_extension_target(
|
define_extension_target(
|
||||||
_C_stable_libtorch
|
_C_stable_libtorch
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <c10/util/BFloat16.h>
|
#include <torch/headeronly/util/BFloat16.h>
|
||||||
#include <c10/util/Half.h>
|
#include <torch/headeronly/util/Half.h>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cute/tensor.hpp>
|
#include <cute/tensor.hpp>
|
||||||
#include <torch/all.h>
|
|
||||||
namespace cute {
|
namespace cute {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
|
|||||||
@@ -189,9 +189,9 @@ struct Sm90RowOrScalarBroadcastArray {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||||
|
|
||||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||||
@@ -211,8 +211,8 @@ struct Sm90RowOrScalarBroadcastArray {
|
|||||||
begin_loop(int epi_m, int epi_n) {
|
begin_loop(int epi_m, int epi_n) {
|
||||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -241,9 +241,9 @@ struct Sm90RowOrScalarBroadcastArray {
|
|||||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||||
using ThreadCount = decltype(size(args.tiled_copy));
|
using ThreadCount = decltype(size(args.tiled_copy));
|
||||||
|
|
||||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
|
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
|
||||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||||
//// G2S: Gmem to Smem
|
//// G2S: Gmem to Smem
|
||||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||||
@@ -251,16 +251,16 @@ struct Sm90RowOrScalarBroadcastArray {
|
|||||||
Stride<_0, _1>>{},
|
Stride<_0, _1>>{},
|
||||||
Layout<_1>{});
|
Layout<_1>{});
|
||||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||||
|
|
||||||
//// G2S: Coord
|
//// G2S: Coord
|
||||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||||
|
|
||||||
//// S2R: Smem to Reg
|
//// S2R: Smem to Reg
|
||||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||||
|
|
||||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||||
tGS_gRow,
|
tGS_gRow,
|
||||||
@@ -389,7 +389,7 @@ struct Sm90ColOrScalarBroadcastArray {
|
|||||||
|
|
||||||
CUTLASS_DEVICE void
|
CUTLASS_DEVICE void
|
||||||
begin() {
|
begin() {
|
||||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < size(pred); ++i) {
|
for (int i = 0; i < size(pred); ++i) {
|
||||||
pred(i) = get<0>(tCcCol(i)) < m;
|
pred(i) = get<0>(tCcCol(i)) < m;
|
||||||
@@ -409,7 +409,7 @@ struct Sm90ColOrScalarBroadcastArray {
|
|||||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||||
Array<Element, FragmentSize> frg_col;
|
Array<Element, FragmentSize> frg_col;
|
||||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < FragmentSize; ++i) {
|
for (int i = 0; i < FragmentSize; ++i) {
|
||||||
@@ -431,16 +431,16 @@ struct Sm90ColOrScalarBroadcastArray {
|
|||||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||||
|
|
||||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
|
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
|
||||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
|
||||||
// Generate an identity tensor matching the shape of the global tensor and
|
// Generate an identity tensor matching the shape of the global tensor and
|
||||||
// partition the same way, this will be used to generate the predicate
|
// partition the same way, this will be used to generate the predicate
|
||||||
// tensor for loading
|
// tensor for loading
|
||||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
cute::Tensor cCol = make_identity_tensor(mCol.shape());
|
||||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
|
|
||||||
return ConsumerStoreCallbacks(
|
return ConsumerStoreCallbacks(
|
||||||
|
|||||||
@@ -186,9 +186,9 @@ struct Sm90RowOrScalarBroadcast {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||||
|
|
||||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||||
@@ -208,8 +208,8 @@ struct Sm90RowOrScalarBroadcast {
|
|||||||
begin_loop(int epi_m, int epi_n) {
|
begin_loop(int epi_m, int epi_n) {
|
||||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -238,9 +238,9 @@ struct Sm90RowOrScalarBroadcast {
|
|||||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||||
using ThreadCount = decltype(size(args.tiled_copy));
|
using ThreadCount = decltype(size(args.tiled_copy));
|
||||||
|
|
||||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
cute::Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||||
//// G2S: Gmem to Smem
|
//// G2S: Gmem to Smem
|
||||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||||
@@ -248,16 +248,16 @@ struct Sm90RowOrScalarBroadcast {
|
|||||||
Stride<_0, _1>>{},
|
Stride<_0, _1>>{},
|
||||||
Layout<_1>{});
|
Layout<_1>{});
|
||||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||||
|
|
||||||
//// G2S: Coord
|
//// G2S: Coord
|
||||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||||
|
|
||||||
//// S2R: Smem to Reg
|
//// S2R: Smem to Reg
|
||||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
cute::Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||||
|
|
||||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||||
tGS_gRow,
|
tGS_gRow,
|
||||||
@@ -382,7 +382,7 @@ struct Sm90ColOrScalarBroadcast {
|
|||||||
|
|
||||||
CUTLASS_DEVICE void
|
CUTLASS_DEVICE void
|
||||||
begin() {
|
begin() {
|
||||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
cute::Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < size(pred); ++i) {
|
for (int i = 0; i < size(pred); ++i) {
|
||||||
pred(i) = get<0>(tCcCol(i)) < m;
|
pred(i) = get<0>(tCcCol(i)) < m;
|
||||||
@@ -402,7 +402,7 @@ struct Sm90ColOrScalarBroadcast {
|
|||||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||||
Array<Element, FragmentSize> frg_col;
|
Array<Element, FragmentSize> frg_col;
|
||||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int i = 0; i < FragmentSize; ++i) {
|
for (int i = 0; i < FragmentSize; ++i) {
|
||||||
@@ -422,16 +422,16 @@ struct Sm90ColOrScalarBroadcast {
|
|||||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||||
|
|
||||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
|
cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
|
||||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
cute::Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
|
||||||
// Generate an identity tensor matching the shape of the global tensor and
|
// Generate an identity tensor matching the shape of the global tensor and
|
||||||
// partition the same way, this will be used to generate the predicate
|
// partition the same way, this will be used to generate the predicate
|
||||||
// tensor for loading
|
// tensor for loading
|
||||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
cute::Tensor cCol = make_identity_tensor(mCol.shape());
|
||||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
cute::Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
|
|
||||||
return ConsumerStoreCallbacks(
|
return ConsumerStoreCallbacks(
|
||||||
|
|||||||
@@ -1,6 +1,21 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <torch/all.h>
|
// This header is shared between _C (unstable ABI, used by machete) and
|
||||||
|
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
|
||||||
|
// is defined only for the stable target, so we switch includes and types
|
||||||
|
// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor.
|
||||||
|
#ifdef TORCH_TARGET_VERSION
|
||||||
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include <torch/headeronly/util/BFloat16.h>
|
||||||
|
#include <torch/headeronly/util/Half.h>
|
||||||
|
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
|
||||||
|
using TorchTensor = torch::stable::Tensor;
|
||||||
|
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
|
||||||
|
#else
|
||||||
|
#include <torch/all.h>
|
||||||
|
using TorchTensor = torch::Tensor;
|
||||||
|
#define TORCH_UTILS_CHECK TORCH_CHECK
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "cute/layout.hpp"
|
#include "cute/layout.hpp"
|
||||||
#include "cutlass/layout/matrix.h"
|
#include "cutlass/layout/matrix.h"
|
||||||
@@ -55,35 +70,35 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
|
|||||||
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
|
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
|
||||||
// strides are set to be 0 or 1.
|
// strides are set to be 0 or 1.
|
||||||
template <typename Stride>
|
template <typename Stride>
|
||||||
static inline auto make_cute_layout(torch::Tensor const& tensor,
|
static inline auto make_cute_layout(TorchTensor const& tensor,
|
||||||
std::string_view name = "tensor") {
|
std::string_view name = "tensor") {
|
||||||
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
|
TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{}));
|
||||||
auto stride = cute::transform_with_idx(
|
auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele,
|
||||||
Stride{}, [&](auto const& stride_ele, auto const& idx) {
|
auto const& idx) {
|
||||||
using StrideEle = std::decay_t<decltype(stride_ele)>;
|
using StrideEle = std::decay_t<decltype(stride_ele)>;
|
||||||
|
|
||||||
if (idx < tensor.dim()) {
|
if (idx < tensor.dim()) {
|
||||||
if constexpr (cute::is_static_v<StrideEle>) {
|
if constexpr (cute::is_static_v<StrideEle>) {
|
||||||
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
|
TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
|
||||||
name, ".stride(", idx, ") to be ", StrideEle::value);
|
name, ".stride(", idx, ") to be ", StrideEle::value);
|
||||||
return StrideEle{};
|
return StrideEle{};
|
||||||
} else {
|
} else {
|
||||||
if (tensor.size(idx) == 1) {
|
if (tensor.size(idx) == 1) {
|
||||||
// use 0 stride for dim with size 1, this is easier for
|
// use 0 stride for dim with size 1, this is easier for
|
||||||
// cute/cutlass to optimize (helps the TMA code flatten dims)
|
// cute/cutlass to optimize (helps the TMA code flatten dims)
|
||||||
return StrideEle{0};
|
return StrideEle{0};
|
||||||
} else {
|
|
||||||
return tensor.stride(idx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// Extra strides are assumed to be 0 or 1
|
return tensor.stride(idx);
|
||||||
if constexpr (cute::is_static_v<StrideEle>) {
|
|
||||||
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
|
|
||||||
}
|
|
||||||
return StrideEle{};
|
|
||||||
}
|
}
|
||||||
});
|
}
|
||||||
|
} else {
|
||||||
|
// Extra strides are assumed to be 0 or 1
|
||||||
|
if constexpr (cute::is_static_v<StrideEle>) {
|
||||||
|
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
|
||||||
|
}
|
||||||
|
return StrideEle{};
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
|
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
|
||||||
if (idx < tensor.dim())
|
if (idx < tensor.dim())
|
||||||
@@ -97,7 +112,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
|
|||||||
|
|
||||||
template <typename Stride>
|
template <typename Stride>
|
||||||
static inline auto maybe_make_cute_layout(
|
static inline auto maybe_make_cute_layout(
|
||||||
std::optional<torch::Tensor> const& tensor,
|
std::optional<TorchTensor> const& tensor,
|
||||||
std::string_view name = "tensor") {
|
std::string_view name = "tensor") {
|
||||||
using Layout = decltype(make_cute_layout<Stride>(*tensor));
|
using Layout = decltype(make_cute_layout<Stride>(*tensor));
|
||||||
|
|
||||||
@@ -121,12 +136,12 @@ template <typename T>
|
|||||||
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
|
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct equivalent_cutlass_type<c10::Half> {
|
struct equivalent_cutlass_type<torch::headeronly::Half> {
|
||||||
using type = cutlass::half_t;
|
using type = cutlass::half_t;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct equivalent_cutlass_type<c10::BFloat16> {
|
struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
|
||||||
using type = cutlass::bfloat16_t;
|
using type = cutlass::bfloat16_t;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -134,8 +149,8 @@ struct equivalent_cutlass_type<c10::BFloat16> {
|
|||||||
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
|
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
|
||||||
//
|
//
|
||||||
|
|
||||||
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
|
// Return a `torch::headeronly::CppTypeToScalarType<T>` compatible type, i.e.
|
||||||
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
|
// get the C++ type equivalent to T, e.g.: `cutlass::half_t -> Half`
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct equivalent_scalar_type {
|
struct equivalent_scalar_type {
|
||||||
using type = T;
|
using type = T;
|
||||||
@@ -146,15 +161,15 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
|
|||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct equivalent_scalar_type<cutlass::half_t> {
|
struct equivalent_scalar_type<cutlass::half_t> {
|
||||||
using type = c10::Half;
|
using type = torch::headeronly::Half;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct equivalent_scalar_type<cutlass::bfloat16_t> {
|
struct equivalent_scalar_type<cutlass::bfloat16_t> {
|
||||||
using type = c10::BFloat16;
|
using type = torch::headeronly::BFloat16;
|
||||||
};
|
};
|
||||||
|
|
||||||
// get equivalent c10::ScalarType tag from compile time type
|
// get equivalent torch::headeronly::ScalarType tag from compile time type
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
|
static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v =
|
||||||
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
|
torch::headeronly::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
|
||||||
|
|||||||
@@ -49,6 +49,15 @@
|
|||||||
THO_DISPATCH_SWITCH(TYPE, NAME, \
|
THO_DISPATCH_SWITCH(TYPE, NAME, \
|
||||||
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
|
VLLM_STABLE_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
|
// Half types dispatch (Half + BFloat16)
|
||||||
|
#define VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(...) \
|
||||||
|
THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \
|
||||||
|
THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__)
|
||||||
|
|
||||||
|
#define VLLM_STABLE_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
|
||||||
|
THO_DISPATCH_SWITCH(TYPE, NAME, \
|
||||||
|
VLLM_STABLE_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))
|
||||||
|
|
||||||
// Boolean dispatch
|
// Boolean dispatch
|
||||||
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
|
#define VLLM_STABLE_DISPATCH_BOOL(expr, const_expr, ...) \
|
||||||
if (expr) { \
|
if (expr) { \
|
||||||
|
|||||||
@@ -84,4 +84,54 @@ void get_cutlass_batched_moe_mm_data(
|
|||||||
const torch::stable::Tensor& expert_num_tokens,
|
const torch::stable::Tensor& expert_num_tokens,
|
||||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||||
const int64_t k);
|
const int64_t k);
|
||||||
|
|
||||||
|
// FP4/NVFP4 ops
|
||||||
|
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
||||||
|
|
||||||
|
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
|
||||||
|
torch::stable::Tensor const& A,
|
||||||
|
torch::stable::Tensor const& B,
|
||||||
|
torch::stable::Tensor const& A_sf,
|
||||||
|
torch::stable::Tensor const& B_sf,
|
||||||
|
torch::stable::Tensor const& alpha);
|
||||||
|
|
||||||
|
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
|
||||||
|
const torch::stable::Tensor& a,
|
||||||
|
const torch::stable::Tensor& b,
|
||||||
|
const torch::stable::Tensor& a_blockscale,
|
||||||
|
const torch::stable::Tensor& b_blockscales,
|
||||||
|
const torch::stable::Tensor& alphas,
|
||||||
|
const torch::stable::Tensor& problem_sizes,
|
||||||
|
const torch::stable::Tensor& expert_offsets,
|
||||||
|
const torch::stable::Tensor& sf_offsets);
|
||||||
|
|
||||||
|
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
|
||||||
|
torch::stable::Tensor const& input,
|
||||||
|
torch::stable::Tensor const& input_scale, bool is_sf_swizzled_layout);
|
||||||
|
|
||||||
|
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
|
||||||
|
torch::stable::Tensor const& input_scale,
|
||||||
|
bool is_sf_swizzled_layout,
|
||||||
|
torch::stable::Tensor& output,
|
||||||
|
torch::stable::Tensor& output_scale);
|
||||||
|
|
||||||
|
void scaled_fp4_experts_quant(
|
||||||
|
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||||
|
torch::stable::Tensor const& input,
|
||||||
|
torch::stable::Tensor const& input_global_scale,
|
||||||
|
torch::stable::Tensor const& input_offset_by_experts,
|
||||||
|
torch::stable::Tensor const& output_scale_offset_by_experts);
|
||||||
|
|
||||||
|
void silu_and_mul_scaled_fp4_experts_quant(
|
||||||
|
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||||
|
torch::stable::Tensor const& input,
|
||||||
|
torch::stable::Tensor const& input_global_scale,
|
||||||
|
torch::stable::Tensor const& input_offset_by_experts,
|
||||||
|
torch::stable::Tensor const& output_scale_offset_by_experts);
|
||||||
|
|
||||||
|
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& out,
|
||||||
|
torch::stable::Tensor& output_block_scale,
|
||||||
|
torch::stable::Tensor& input,
|
||||||
|
torch::stable::Tensor& input_global_scale);
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -2,10 +2,9 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
#include <c10/cuda/CUDAStream.h>
|
#include "libtorch_stable/torch_utils.h"
|
||||||
|
|
||||||
#include "core/scalar_type.hpp"
|
|
||||||
#include "cutlass/bfloat16.h"
|
#include "cutlass/bfloat16.h"
|
||||||
#include "cutlass/float8.h"
|
#include "cutlass/float8.h"
|
||||||
|
|
||||||
@@ -41,7 +40,7 @@ __global__ void get_group_gemm_starts(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
||||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
|
||||||
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
|
get_group_gemm_starts<cutlass::float_e4m3_t, int32_t, C_TYPE, float, \
|
||||||
cutlass::Array<cutlass::float_e4m3_t, 8>> \
|
cutlass::Array<cutlass::float_e4m3_t, 8>> \
|
||||||
<<<1, num_experts, 0, stream>>>( \
|
<<<1, num_experts, 0, stream>>>( \
|
||||||
@@ -66,23 +65,34 @@ __global__ void get_group_gemm_starts(
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void run_get_group_gemm_starts(
|
void run_get_group_gemm_starts(
|
||||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
|
||||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
|
||||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
|
||||||
torch::Tensor& b_group_scales_ptrs, torch::Tensor const& a_tensors,
|
torch::stable::Tensor& b_group_scales_ptrs,
|
||||||
torch::Tensor const& b_tensors, torch::Tensor& out_tensors,
|
torch::stable::Tensor const& a_tensors,
|
||||||
torch::Tensor const& a_scales, torch::Tensor const& b_scales,
|
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
|
||||||
torch::Tensor const& b_group_scales, const int64_t b_group_size) {
|
torch::stable::Tensor const& a_scales,
|
||||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
torch::stable::Tensor const& b_scales,
|
||||||
TORCH_CHECK(b_tensors.dtype() == torch::kInt32); // int4 8x packed into int32
|
torch::stable::Tensor const& b_group_scales, const int64_t b_group_size) {
|
||||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
STD_TORCH_CHECK(a_tensors.scalar_type() ==
|
||||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
TORCH_CHECK(b_group_scales.dtype() ==
|
STD_TORCH_CHECK(
|
||||||
torch::kFloat8_e4m3fn); // the underlying torch type is e4m3
|
b_tensors.scalar_type() ==
|
||||||
TORCH_CHECK(out_tensors.dtype() ==
|
torch::headeronly::ScalarType::Int); // int4 8x packed into int32
|
||||||
torch::kBFloat16); // only support bf16 for now
|
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float);
|
||||||
|
STD_TORCH_CHECK(
|
||||||
|
b_group_scales.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Float8_e4m3fn); // the underlying torch
|
||||||
|
// type is e4m3
|
||||||
|
STD_TORCH_CHECK(
|
||||||
|
out_tensors.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::BFloat16); // only support bf16 for now
|
||||||
// expect int64_t to avoid overflow during offset calculations
|
// expect int64_t to avoid overflow during offset calculations
|
||||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
|
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
|
||||||
|
torch::headeronly::ScalarType::Long);
|
||||||
|
|
||||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||||
// logical k, n
|
// logical k, n
|
||||||
@@ -90,15 +100,16 @@ void run_get_group_gemm_starts(
|
|||||||
int64_t k = a_tensors.size(1);
|
int64_t k = a_tensors.size(1);
|
||||||
int64_t scale_k = cutlass::ceil_div(k, b_group_size);
|
int64_t scale_k = cutlass::ceil_div(k, b_group_size);
|
||||||
|
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||||
|
|
||||||
if (false) {
|
if (false) {
|
||||||
}
|
}
|
||||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
|
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
|
||||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
|
cutlass::bfloat16_t)
|
||||||
|
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
|
||||||
else {
|
else {
|
||||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -14,13 +14,12 @@
|
|||||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||||
|
|
||||||
// vllm includes
|
// vllm includes
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <torch/csrc/stable/library.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
#include <torch/all.h>
|
#include "libtorch_stable/torch_utils.h"
|
||||||
#include "cutlass_extensions/torch_utils.hpp"
|
#include "cutlass_extensions/torch_utils.hpp"
|
||||||
#include "cutlass_extensions/common.hpp"
|
#include "cutlass_extensions/common.hpp"
|
||||||
|
|
||||||
#include "core/registration.h"
|
|
||||||
#include "get_group_starts.cuh"
|
#include "get_group_starts.cuh"
|
||||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||||
#include "w4a8_utils.cuh"
|
#include "w4a8_utils.cuh"
|
||||||
@@ -168,31 +167,40 @@ struct W4A8GroupedGemmKernel {
|
|||||||
static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
|
static_assert(sizeof(LayoutB_Reordered) % sizeof(int32_t) == 0,
|
||||||
"LayoutB_Reordered size must be divisible by 4 bytes");
|
"LayoutB_Reordered size must be divisible by 4 bytes");
|
||||||
|
|
||||||
static void grouped_mm(
|
static void grouped_mm(torch::stable::Tensor& out_tensors,
|
||||||
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
|
const torch::stable::Tensor& a_tensors,
|
||||||
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
|
const torch::stable::Tensor& b_tensors,
|
||||||
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
|
const torch::stable::Tensor& a_scales,
|
||||||
const int64_t b_group_size, const torch::Tensor& expert_offsets,
|
const torch::stable::Tensor& b_scales,
|
||||||
const torch::Tensor& problem_sizes_torch, const torch::Tensor& a_strides,
|
const torch::stable::Tensor& b_group_scales,
|
||||||
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
|
const int64_t b_group_size,
|
||||||
const torch::Tensor& group_scale_strides) {
|
const torch::stable::Tensor& expert_offsets,
|
||||||
|
const torch::stable::Tensor& problem_sizes_torch,
|
||||||
|
const torch::stable::Tensor& a_strides,
|
||||||
|
const torch::stable::Tensor& b_strides,
|
||||||
|
const torch::stable::Tensor& c_strides,
|
||||||
|
const torch::stable::Tensor& group_scale_strides) {
|
||||||
auto device = a_tensors.device();
|
auto device = a_tensors.device();
|
||||||
auto device_id = device.index();
|
auto device_id = device.index();
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device);
|
const torch::stable::accelerator::DeviceGuard device_guard(device_id);
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(device_id);
|
auto stream = get_current_cuda_stream(device_id);
|
||||||
|
|
||||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||||
int n = static_cast<int>(b_tensors.size(1));
|
int n = static_cast<int>(b_tensors.size(1));
|
||||||
int k = static_cast<int>(b_tensors.size(2)) * PackFactor;
|
int k = static_cast<int>(b_tensors.size(2)) * PackFactor;
|
||||||
|
|
||||||
auto options_int =
|
torch::stable::Tensor a_ptrs = torch::stable::empty(
|
||||||
torch::TensorOptions().dtype(torch::kInt64).device(device);
|
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::Tensor b_ptrs = torch::stable::empty(
|
||||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::Tensor out_ptrs = torch::stable::empty(
|
||||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
|
||||||
torch::Tensor b_group_scales_ptrs = torch::empty(num_experts, options_int);
|
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||||
|
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
|
||||||
|
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||||
|
torch::stable::Tensor b_group_scales_ptrs = torch::stable::empty(
|
||||||
|
num_experts, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||||
|
|
||||||
// get the correct offsets to pass to gemm
|
// get the correct offsets to pass to gemm
|
||||||
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
|
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
|
||||||
@@ -247,9 +255,9 @@ struct W4A8GroupedGemmKernel {
|
|||||||
|
|
||||||
// Allocate workspace
|
// Allocate workspace
|
||||||
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
|
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
|
||||||
torch::Tensor workspace =
|
torch::stable::Tensor workspace = torch::stable::empty(
|
||||||
torch::empty(workspace_size,
|
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
|
||||||
torch::TensorOptions().dtype(torch::kU8).device(device));
|
device);
|
||||||
|
|
||||||
// Run GEMM
|
// Run GEMM
|
||||||
GemmShuffled gemm;
|
GemmShuffled gemm;
|
||||||
@@ -294,14 +302,20 @@ using Kernel_256x128_2x1x1_Coop =
|
|||||||
using Kernel_128x256_2x1x1_Coop =
|
using Kernel_128x256_2x1x1_Coop =
|
||||||
W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;
|
W4A8GroupedGemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>, Coop, CoopEpi>;
|
||||||
|
|
||||||
void mm_dispatch(
|
void mm_dispatch(torch::stable::Tensor& out_tensors,
|
||||||
torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
|
const torch::stable::Tensor& a_tensors,
|
||||||
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
|
const torch::stable::Tensor& b_tensors,
|
||||||
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
|
const torch::stable::Tensor& a_scales,
|
||||||
const int64_t b_group_size, const torch::Tensor& expert_offsets,
|
const torch::stable::Tensor& b_scales,
|
||||||
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
|
const torch::stable::Tensor& b_group_scales,
|
||||||
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
|
const int64_t b_group_size,
|
||||||
const torch::Tensor& group_scale_strides, const std::string& schedule) {
|
const torch::stable::Tensor& expert_offsets,
|
||||||
|
const torch::stable::Tensor& problem_sizes,
|
||||||
|
const torch::stable::Tensor& a_strides,
|
||||||
|
const torch::stable::Tensor& b_strides,
|
||||||
|
const torch::stable::Tensor& c_strides,
|
||||||
|
const torch::stable::Tensor& group_scale_strides,
|
||||||
|
const std::string& schedule) {
|
||||||
if (schedule == "Kernel_128x16_1x1x1_Coop") {
|
if (schedule == "Kernel_128x16_1x1x1_Coop") {
|
||||||
Kernel_128x16_1x1x1_Coop::grouped_mm(
|
Kernel_128x16_1x1x1_Coop::grouped_mm(
|
||||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
out_tensors, a_tensors, b_tensors, a_scales, b_scales, b_group_scales,
|
||||||
@@ -358,18 +372,23 @@ void mm_dispatch(
|
|||||||
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
b_group_size, expert_offsets, problem_sizes, a_strides, b_strides,
|
||||||
c_strides, group_scale_strides);
|
c_strides, group_scale_strides);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false,
|
STD_TORCH_CHECK(false,
|
||||||
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
|
"cutlass_w4a8_moe_mm: unknown schedule string: ", schedule);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
|
void mm(torch::stable::Tensor& out_tensors,
|
||||||
const torch::Tensor& b_tensors, const torch::Tensor& a_scales,
|
const torch::stable::Tensor& a_tensors,
|
||||||
const torch::Tensor& b_scales, const torch::Tensor& b_group_scales,
|
const torch::stable::Tensor& b_tensors,
|
||||||
const int64_t b_group_size, const torch::Tensor& expert_offsets,
|
const torch::stable::Tensor& a_scales,
|
||||||
const torch::Tensor& problem_sizes, const torch::Tensor& a_strides,
|
const torch::stable::Tensor& b_scales,
|
||||||
const torch::Tensor& b_strides, const torch::Tensor& c_strides,
|
const torch::stable::Tensor& b_group_scales, const int64_t b_group_size,
|
||||||
const torch::Tensor& group_scale_strides,
|
const torch::stable::Tensor& expert_offsets,
|
||||||
|
const torch::stable::Tensor& problem_sizes,
|
||||||
|
const torch::stable::Tensor& a_strides,
|
||||||
|
const torch::stable::Tensor& b_strides,
|
||||||
|
const torch::stable::Tensor& c_strides,
|
||||||
|
const torch::stable::Tensor& group_scale_strides,
|
||||||
std::optional<std::string> maybe_schedule) {
|
std::optional<std::string> maybe_schedule) {
|
||||||
// user has specified a schedule
|
// user has specified a schedule
|
||||||
if (maybe_schedule) {
|
if (maybe_schedule) {
|
||||||
@@ -406,26 +425,27 @@ void mm(torch::Tensor& out_tensors, const torch::Tensor& a_tensors,
|
|||||||
a_strides, b_strides, c_strides, group_scale_strides, schedule);
|
a_strides, b_strides, c_strides, group_scale_strides, schedule);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
|
std::tuple<torch::stable::Tensor, torch::stable::Tensor>
|
||||||
torch::Tensor const& b_tensors) {
|
encode_and_reorder_int4b(torch::stable::Tensor const& b_tensors) {
|
||||||
TORCH_CHECK(b_tensors.dtype() == torch::kInt32);
|
STD_TORCH_CHECK(b_tensors.scalar_type() ==
|
||||||
TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
|
torch::headeronly::ScalarType::Int);
|
||||||
TORCH_CHECK(b_tensors.is_contiguous());
|
STD_TORCH_CHECK(b_tensors.dim() == 3); // (experts, n, k)
|
||||||
TORCH_CHECK(b_tensors.is_cuda());
|
STD_TORCH_CHECK(b_tensors.is_contiguous());
|
||||||
|
STD_TORCH_CHECK(b_tensors.is_cuda());
|
||||||
|
|
||||||
int n = static_cast<int>(b_tensors.size(1));
|
int n = static_cast<int>(b_tensors.size(1));
|
||||||
int k = static_cast<int>(b_tensors.size(2)) * PackFactor; // logical k
|
int k = static_cast<int>(b_tensors.size(2)) * PackFactor; // logical k
|
||||||
|
|
||||||
// CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
|
// CUTLASS reorder_tensor requires k % 256 == 0 and n % 16 == 0.
|
||||||
// These misalignments cause silent OOB unless run under Compute Sanitizer.
|
// These misalignments cause silent OOB unless run under Compute Sanitizer.
|
||||||
TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
|
STD_TORCH_CHECK(k % 256 == 0, "logical k must be divisible by 256");
|
||||||
TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
|
STD_TORCH_CHECK(n % 16 == 0, "n must be divisible by 16");
|
||||||
|
|
||||||
// we will store the layout to an int32 tensor;
|
// we will store the layout to an int32 tensor;
|
||||||
// this is the number of elements we need per layout
|
// this is the number of elements we need per layout
|
||||||
constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);
|
constexpr size_t layout_width = sizeof(LayoutB_Reordered) / sizeof(int32_t);
|
||||||
|
|
||||||
torch::Tensor b_tensors_packed = torch::empty_like(b_tensors);
|
torch::stable::Tensor b_tensors_packed = torch::stable::empty_like(b_tensors);
|
||||||
int num_experts = static_cast<int>(b_tensors.size(0));
|
int num_experts = static_cast<int>(b_tensors.size(0));
|
||||||
|
|
||||||
auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
|
auto b_ptr = static_cast<QuantType const*>(b_tensors.const_data_ptr());
|
||||||
@@ -435,7 +455,7 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
|
|||||||
size_t num_int4_elems = 1ull * num_experts * n * k;
|
size_t num_int4_elems = 1ull * num_experts * n * k;
|
||||||
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
|
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(b_ptr, b_packed_ptr,
|
||||||
num_int4_elems);
|
num_int4_elems);
|
||||||
TORCH_CHECK(ok, "unified_encode_int4b failed");
|
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
|
||||||
|
|
||||||
// construct the layout once; assumes each expert has the same layout
|
// construct the layout once; assumes each expert has the same layout
|
||||||
using LayoutType = LayoutB_Reordered;
|
using LayoutType = LayoutB_Reordered;
|
||||||
@@ -456,28 +476,28 @@ std::tuple<torch::Tensor, torch::Tensor> encode_and_reorder_int4b(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// save the packed layout to torch tensor so we can re-use it
|
// save the packed layout to torch tensor so we can re-use it
|
||||||
auto cpu_opts =
|
torch::stable::Tensor layout_cpu = torch::stable::empty(
|
||||||
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
|
{num_experts, layout_width}, torch::headeronly::ScalarType::Int,
|
||||||
torch::Tensor layout_cpu =
|
std::nullopt, torch::stable::Device(torch::stable::DeviceType::CPU));
|
||||||
torch::empty({num_experts, layout_width}, cpu_opts);
|
|
||||||
|
|
||||||
int32_t* layout_data = layout_cpu.data_ptr<int32_t>();
|
int32_t* layout_data = layout_cpu.mutable_data_ptr<int32_t>();
|
||||||
for (int i = 0; i < num_experts; ++i) {
|
for (int i = 0; i < num_experts; ++i) {
|
||||||
std::memcpy(layout_data + i * layout_width, // dst (int32*)
|
std::memcpy(layout_data + i * layout_width, // dst (int32*)
|
||||||
&layout_B_reordered, // src (LayoutType*)
|
&layout_B_reordered, // src (LayoutType*)
|
||||||
sizeof(LayoutType)); // number of bytes
|
sizeof(LayoutType)); // number of bytes
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor packed_layout =
|
torch::stable::Tensor packed_layout =
|
||||||
layout_cpu.to(b_tensors.device(), /*non_blocking=*/false);
|
torch::stable::to(layout_cpu, b_tensors.device(),
|
||||||
|
/*non_blocking=*/false);
|
||||||
|
|
||||||
return {b_tensors_packed, packed_layout};
|
return {b_tensors_packed, packed_layout};
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||||
m.impl("cutlass_w4a8_moe_mm", &mm);
|
m.impl("cutlass_w4a8_moe_mm", TORCH_BOX(&mm));
|
||||||
m.impl("cutlass_encode_and_reorder_int4b_grouped", &encode_and_reorder_int4b);
|
m.impl("cutlass_encode_and_reorder_int4b_grouped",
|
||||||
|
TORCH_BOX(&encode_and_reorder_int4b));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm::cutlass_w4a8_moe
|
} // namespace vllm::cutlass_w4a8_moe
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
@@ -3,14 +3,12 @@
|
|||||||
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
|
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <torch/csrc/stable/library.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
#include <torch/all.h>
|
#include "libtorch_stable/torch_utils.h"
|
||||||
#include "cutlass_extensions/torch_utils.hpp"
|
#include "cutlass_extensions/torch_utils.hpp"
|
||||||
#include "w4a8_utils.cuh"
|
#include "w4a8_utils.cuh"
|
||||||
|
|
||||||
#include "core/registration.h"
|
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
@@ -161,31 +159,31 @@ struct W4A8GemmKernel {
|
|||||||
using StrideD = typename GemmKernelShuffled::StrideD;
|
using StrideD = typename GemmKernelShuffled::StrideD;
|
||||||
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
|
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
|
||||||
|
|
||||||
static torch::Tensor mm(torch::Tensor const& A,
|
static torch::stable::Tensor mm(
|
||||||
torch::Tensor const& B, // already packed
|
torch::stable::Tensor const& A,
|
||||||
torch::Tensor const& group_scales, // already packed
|
torch::stable::Tensor const& B, // already packed
|
||||||
int64_t group_size,
|
torch::stable::Tensor const& group_scales, // already packed
|
||||||
torch::Tensor const& channel_scales,
|
int64_t group_size, torch::stable::Tensor const& channel_scales,
|
||||||
torch::Tensor const& token_scales,
|
torch::stable::Tensor const& token_scales,
|
||||||
std::optional<at::ScalarType> const& maybe_out_type) {
|
std::optional<torch::headeronly::ScalarType> const& maybe_out_type) {
|
||||||
// TODO: param validation
|
// TODO: param validation
|
||||||
int m = A.size(0);
|
int m = A.size(0);
|
||||||
int k = A.size(1);
|
int k = A.size(1);
|
||||||
int n = B.size(1);
|
int n = B.size(1);
|
||||||
|
|
||||||
// safely cast group_size to int
|
// safely cast group_size to int
|
||||||
TORCH_CHECK(group_size > 0 && group_size <= std::numeric_limits<int>::max(),
|
STD_TORCH_CHECK(
|
||||||
"group_size out of supported range for int: ", group_size);
|
group_size > 0 && group_size <= std::numeric_limits<int>::max(),
|
||||||
|
"group_size out of supported range for int: ", group_size);
|
||||||
int const group_size_int = static_cast<int>(group_size);
|
int const group_size_int = static_cast<int>(group_size);
|
||||||
|
|
||||||
// Allocate output
|
// Allocate output
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
|
A.get_device_index());
|
||||||
auto device = A.device();
|
auto device = A.device();
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
auto stream = get_current_cuda_stream(device.index());
|
||||||
torch::Tensor D =
|
torch::stable::Tensor D = torch::stable::empty(
|
||||||
torch::empty({m, n}, torch::TensorOptions()
|
{m, n}, equivalent_scalar_type_v<ElementD>, std::nullopt, device);
|
||||||
.dtype(equivalent_scalar_type_v<ElementD>)
|
|
||||||
.device(device));
|
|
||||||
// prepare arg pointers
|
// prepare arg pointers
|
||||||
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
|
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
|
||||||
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
||||||
@@ -237,9 +235,9 @@ struct W4A8GemmKernel {
|
|||||||
|
|
||||||
// Workspace
|
// Workspace
|
||||||
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
|
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
|
||||||
torch::Tensor workspace =
|
torch::stable::Tensor workspace = torch::stable::empty(
|
||||||
torch::empty(workspace_size,
|
workspace_size, torch::headeronly::ScalarType::Byte, std::nullopt,
|
||||||
torch::TensorOptions().dtype(torch::kU8).device(device));
|
device);
|
||||||
|
|
||||||
// Run GEMM
|
// Run GEMM
|
||||||
GemmShuffled gemm;
|
GemmShuffled gemm;
|
||||||
@@ -269,14 +267,14 @@ using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
|
|||||||
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
|
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
|
||||||
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
|
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
|
||||||
|
|
||||||
torch::Tensor mm_dispatch(torch::Tensor const& A,
|
torch::stable::Tensor mm_dispatch(
|
||||||
torch::Tensor const& B, // already packed
|
torch::stable::Tensor const& A,
|
||||||
torch::Tensor const& group_scales, // already packed
|
torch::stable::Tensor const& B, // already packed
|
||||||
int64_t group_size,
|
torch::stable::Tensor const& group_scales, // already packed
|
||||||
torch::Tensor const& channel_scales,
|
int64_t group_size, torch::stable::Tensor const& channel_scales,
|
||||||
torch::Tensor const& token_scales,
|
torch::stable::Tensor const& token_scales,
|
||||||
std::optional<at::ScalarType> const& maybe_out_type,
|
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
|
||||||
const std::string& schedule) {
|
const std::string& schedule) {
|
||||||
if (schedule == "256x128_1x1x1") {
|
if (schedule == "256x128_1x1x1") {
|
||||||
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
|
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
|
||||||
channel_scales, token_scales,
|
channel_scales, token_scales,
|
||||||
@@ -318,17 +316,18 @@ torch::Tensor mm_dispatch(torch::Tensor const& A,
|
|||||||
channel_scales, token_scales,
|
channel_scales, token_scales,
|
||||||
maybe_out_type);
|
maybe_out_type);
|
||||||
}
|
}
|
||||||
TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
|
STD_TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor mm(torch::Tensor const& A,
|
torch::stable::Tensor mm(
|
||||||
torch::Tensor const& B, // already packed
|
torch::stable::Tensor const& A,
|
||||||
torch::Tensor const& group_scales, // already packed
|
torch::stable::Tensor const& B, // already packed
|
||||||
int64_t group_size, torch::Tensor const& channel_scales,
|
torch::stable::Tensor const& group_scales, // already packed
|
||||||
torch::Tensor const& token_scales,
|
int64_t group_size, torch::stable::Tensor const& channel_scales,
|
||||||
std::optional<at::ScalarType> const& maybe_out_type,
|
torch::stable::Tensor const& token_scales,
|
||||||
std::optional<std::string> maybe_schedule) {
|
std::optional<torch::headeronly::ScalarType> const& maybe_out_type,
|
||||||
|
std::optional<std::string> maybe_schedule) {
|
||||||
// requested a specific schedule
|
// requested a specific schedule
|
||||||
if (maybe_schedule) {
|
if (maybe_schedule) {
|
||||||
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
|
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
|
||||||
@@ -378,14 +377,15 @@ torch::Tensor mm(torch::Tensor const& A,
|
|||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// Pre-processing utils
|
// Pre-processing utils
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
|
torch::stable::Tensor pack_scale_fp8(torch::stable::Tensor const& scales) {
|
||||||
TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
|
STD_TORCH_CHECK(scales.scalar_type() ==
|
||||||
TORCH_CHECK(scales.is_contiguous());
|
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||||
TORCH_CHECK(scales.is_cuda());
|
STD_TORCH_CHECK(scales.is_contiguous());
|
||||||
|
STD_TORCH_CHECK(scales.is_cuda());
|
||||||
|
|
||||||
auto packed_scales = torch::empty(
|
auto packed_scales =
|
||||||
{scales.numel() * ScalePackSize},
|
torch::stable::empty({scales.numel() * ScalePackSize},
|
||||||
torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
|
scales.scalar_type(), std::nullopt, scales.device());
|
||||||
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
|
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
|
||||||
auto packed_scales_ptr =
|
auto packed_scales_ptr =
|
||||||
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
|
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
|
||||||
@@ -396,15 +396,16 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
|
|||||||
return packed_scales;
|
return packed_scales;
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
torch::stable::Tensor encode_and_reorder_int4b(torch::stable::Tensor const& B) {
|
||||||
TORCH_CHECK(B.dtype() == torch::kInt32);
|
STD_TORCH_CHECK(B.scalar_type() == torch::headeronly::ScalarType::Int);
|
||||||
TORCH_CHECK(B.dim() == 2);
|
STD_TORCH_CHECK(B.dim() == 2);
|
||||||
|
|
||||||
torch::Tensor B_packed = torch::empty_like(B);
|
torch::stable::Tensor B_packed = torch::stable::empty_like(B);
|
||||||
|
|
||||||
int k = B.size(0) * PackFactor; // logical k
|
int k = B.size(0) * PackFactor; // logical k
|
||||||
int n = B.size(1);
|
int n = B.size(1);
|
||||||
TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");
|
STD_TORCH_CHECK((n * k) % 32 == 0,
|
||||||
|
"need multiples of 32 int4s for 16B chunks");
|
||||||
|
|
||||||
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
||||||
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
|
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
|
||||||
@@ -415,16 +416,17 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
|||||||
|
|
||||||
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
|
bool ok = vllm::cutlass_w4a8_utils::unified_encode_int4b(B_ptr, B_packed_ptr,
|
||||||
n * k);
|
n * k);
|
||||||
TORCH_CHECK(ok, "unified_encode_int4b failed");
|
STD_TORCH_CHECK(ok, "unified_encode_int4b failed");
|
||||||
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
|
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
|
||||||
|
|
||||||
return B_packed;
|
return B_packed;
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||||
m.impl("cutlass_w4a8_mm", &mm);
|
m.impl("cutlass_w4a8_mm", TORCH_BOX(&mm));
|
||||||
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
|
m.impl("cutlass_pack_scale_fp8", TORCH_BOX(&pack_scale_fp8));
|
||||||
m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
|
m.impl("cutlass_encode_and_reorder_int4b",
|
||||||
|
TORCH_BOX(&encode_and_reorder_int4b));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace vllm::cutlass_w4a8
|
} // namespace vllm::cutlass_w4a8
|
||||||
@@ -14,16 +14,15 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include "libtorch_stable/torch_utils.h"
|
||||||
|
#include "libtorch_stable/dispatch_utils.h"
|
||||||
|
#include "cuda_vec_utils.cuh"
|
||||||
|
|
||||||
#include <cuda_runtime_api.h>
|
#include <cuda_runtime_api.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
|
|
||||||
#include <cuda_fp8.h>
|
#include <cuda_fp8.h>
|
||||||
#include "dispatch_utils.h"
|
|
||||||
|
|
||||||
#include "cuda_utils.h"
|
#include "cuda_utils.h"
|
||||||
#include "launch_bounds_utils.h"
|
#include "launch_bounds_utils.h"
|
||||||
@@ -118,17 +117,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
void silu_and_mul_nvfp4_quant_sm1xxa(
|
||||||
torch::Tensor& output_sf,
|
torch::stable::Tensor& output, // [..., d]
|
||||||
torch::Tensor& input, // [..., 2 * d]
|
torch::stable::Tensor& output_sf,
|
||||||
torch::Tensor& input_sf) {
|
torch::stable::Tensor& input, // [..., 2 * d]
|
||||||
|
torch::stable::Tensor& input_sf) {
|
||||||
int32_t m = input.size(0);
|
int32_t m = input.size(0);
|
||||||
int32_t n = input.size(1) / 2;
|
int32_t n = input.size(1) / 2;
|
||||||
|
|
||||||
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
||||||
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
|
STD_TORCH_CHECK(
|
||||||
input.scalar_type() == at::ScalarType::BFloat16,
|
input.scalar_type() == torch::headeronly::ScalarType::Half ||
|
||||||
"Unsupported input data type for quantize_to_fp4.");
|
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
|
||||||
|
"Unsupported input data type for quantize_to_fp4.");
|
||||||
|
|
||||||
int multiProcessorCount =
|
int multiProcessorCount =
|
||||||
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
|
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
|
||||||
@@ -136,8 +137,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
|||||||
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
|
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
|
||||||
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
|
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
|
||||||
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
|
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
input.get_device_index());
|
||||||
|
auto stream = get_current_cuda_stream(input.get_device_index());
|
||||||
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
|
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
|
||||||
int const numBlocksPerSM =
|
int const numBlocksPerSM =
|
||||||
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
|
||||||
@@ -149,7 +151,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
|||||||
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
||||||
dim3 grid(grid_x, grid_y);
|
dim3 grid(grid_x, grid_y);
|
||||||
|
|
||||||
VLLM_DISPATCH_HALF_TYPES(
|
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||||
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
|
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
|
||||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||||
@@ -14,14 +14,12 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "core/registration.h"
|
#include <torch/csrc/stable/library.h>
|
||||||
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include "libtorch_stable/torch_utils.h"
|
||||||
|
|
||||||
#include <torch/all.h>
|
|
||||||
#include <cutlass/arch/arch.h>
|
#include <cutlass/arch/arch.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
#include <c10/cuda/CUDAStream.h>
|
|
||||||
#include "cutlass_extensions/common.hpp"
|
#include "cutlass_extensions/common.hpp"
|
||||||
|
|
||||||
#include "cute/tensor.hpp"
|
#include "cute/tensor.hpp"
|
||||||
@@ -122,7 +120,7 @@ __global__ void __get_group_gemm_starts(
|
|||||||
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
|
#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \
|
||||||
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
|
TENSOR_C_TYPE, C_TYPE, LayoutSFA, \
|
||||||
LayoutSFB, ScaleConfig) \
|
LayoutSFB, ScaleConfig) \
|
||||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
|
||||||
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
|
__get_group_gemm_starts<ELEMENT_AB_TYPE, C_TYPE, SF_TYPE, float, \
|
||||||
LayoutSFA, LayoutSFB, ScaleConfig> \
|
LayoutSFA, LayoutSFB, ScaleConfig> \
|
||||||
<<<1, num_experts, 0, stream>>>( \
|
<<<1, num_experts, 0, stream>>>( \
|
||||||
@@ -150,50 +148,64 @@ __global__ void __get_group_gemm_starts(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||||
void run_get_group_gemm_starts(
|
void run_get_group_gemm_starts(const torch::stable::Tensor& a_starts,
|
||||||
const torch::Tensor& a_starts, const torch::Tensor& b_starts,
|
const torch::stable::Tensor& b_starts,
|
||||||
const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts,
|
const torch::stable::Tensor& out_starts,
|
||||||
const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts,
|
const torch::stable::Tensor& a_scales_starts,
|
||||||
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
const torch::stable::Tensor& b_scales_starts,
|
||||||
const torch::Tensor& a_strides, const torch::Tensor& b_strides,
|
const torch::stable::Tensor& alpha_starts,
|
||||||
const torch::Tensor& c_strides, int64_t a_stride_val, int64_t b_stride_val,
|
const torch::stable::Tensor& layout_sfa,
|
||||||
int64_t c_stride_val,
|
const torch::stable::Tensor& layout_sfb,
|
||||||
/*these are used for their base addresses*/
|
const torch::stable::Tensor& a_strides,
|
||||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
const torch::stable::Tensor& b_strides,
|
||||||
torch::Tensor const& out_tensors, torch::Tensor const& a_scales,
|
const torch::stable::Tensor& c_strides,
|
||||||
torch::Tensor const& b_scales, torch::Tensor const& alphas,
|
int64_t a_stride_val, int64_t b_stride_val,
|
||||||
torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets,
|
int64_t c_stride_val,
|
||||||
torch::Tensor const& problem_sizes, int M, int N, int K) {
|
/*these are used for their base addresses*/
|
||||||
|
torch::stable::Tensor const& a_tensors,
|
||||||
|
torch::stable::Tensor const& b_tensors,
|
||||||
|
torch::stable::Tensor const& out_tensors,
|
||||||
|
torch::stable::Tensor const& a_scales,
|
||||||
|
torch::stable::Tensor const& b_scales,
|
||||||
|
torch::stable::Tensor const& alphas,
|
||||||
|
torch::stable::Tensor const& expert_offsets,
|
||||||
|
torch::stable::Tensor const& sf_offsets,
|
||||||
|
torch::stable::Tensor const& problem_sizes,
|
||||||
|
int M, int N, int K) {
|
||||||
int num_experts = (int)expert_offsets.size(0);
|
int num_experts = (int)expert_offsets.size(0);
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||||
|
|
||||||
TORCH_CHECK(out_tensors.size(1) == N,
|
STD_TORCH_CHECK(out_tensors.size(1) == N,
|
||||||
"Output tensor shape doesn't match expected shape");
|
"Output tensor shape doesn't match expected shape");
|
||||||
TORCH_CHECK(K / 2 == b_tensors.size(2),
|
STD_TORCH_CHECK(K / 2 == b_tensors.size(2),
|
||||||
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
|
"b_tensors(dim = 2) and a_tensors(dim = 1) trailing"
|
||||||
" dimension must match");
|
" dimension must match");
|
||||||
if (false) {
|
if (false) {
|
||||||
}
|
}
|
||||||
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
|
//(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB,
|
||||||
// ScaleConfig)
|
// ScaleConfig)
|
||||||
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
|
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(
|
||||||
cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16,
|
cutlass::float_e2m1_t, cutlass::float_ue4m3_t,
|
||||||
cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig)
|
torch::headeronly::ScalarType::BFloat16, cutlass::bfloat16_t, LayoutSFA,
|
||||||
|
LayoutSFB, ScaleConfig)
|
||||||
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
|
__CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t,
|
||||||
cutlass::float_ue4m3_t, torch::kFloat16,
|
cutlass::float_ue4m3_t,
|
||||||
half, LayoutSFA, LayoutSFB, ScaleConfig)
|
torch::headeronly::ScalarType::Half, half,
|
||||||
|
LayoutSFA, LayoutSFB, ScaleConfig)
|
||||||
else {
|
else {
|
||||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename OutType>
|
template <typename OutType>
|
||||||
void run_fp4_blockwise_scaled_group_mm_sm100(
|
void run_fp4_blockwise_scaled_group_mm_sm100(
|
||||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
torch::stable::Tensor& output, const torch::stable::Tensor& a,
|
||||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
|
||||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
const torch::stable::Tensor& b_blockscales,
|
||||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
|
const torch::stable::Tensor& alphas,
|
||||||
int N, int K) {
|
const torch::stable::Tensor& problem_sizes,
|
||||||
|
const torch::stable::Tensor& expert_offsets,
|
||||||
|
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
|
||||||
using ProblemShape =
|
using ProblemShape =
|
||||||
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
|
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
|
||||||
using ElementType = cutlass::float_e2m1_t;
|
using ElementType = cutlass::float_e2m1_t;
|
||||||
@@ -272,20 +284,40 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
|
|||||||
|
|
||||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||||
auto options_int =
|
|
||||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
|
||||||
|
|
||||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::Tensor a_ptrs =
|
||||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
std::nullopt, a.device());
|
||||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::Tensor b_ptrs =
|
||||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
|
std::nullopt, a.device());
|
||||||
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
|
torch::stable::Tensor out_ptrs =
|
||||||
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
|
std::nullopt, a.device());
|
||||||
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
|
torch::stable::Tensor a_scales_ptrs =
|
||||||
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
torch::stable::Tensor b_scales_ptrs =
|
||||||
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
torch::stable::Tensor alpha_ptrs =
|
||||||
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
torch::stable::Tensor layout_sfa = torch::stable::empty(
|
||||||
|
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
|
||||||
|
a.device());
|
||||||
|
torch::stable::Tensor layout_sfb = torch::stable::empty(
|
||||||
|
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
|
||||||
|
a.device());
|
||||||
|
torch::stable::Tensor a_strides1 =
|
||||||
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
torch::stable::Tensor b_strides1 =
|
||||||
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
torch::stable::Tensor c_strides1 =
|
||||||
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
|
||||||
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
||||||
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
|
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
|
||||||
@@ -308,7 +340,7 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
|
|||||||
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions;
|
||||||
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||||
scheduler.raster_order = RasterOrderOptions::AlongM;
|
scheduler.raster_order = RasterOrderOptions::AlongM;
|
||||||
hw_info.device_id = a.get_device();
|
hw_info.device_id = a.get_device_index();
|
||||||
static std::unordered_map<int, int> cached_sm_counts;
|
static std::unordered_map<int, int> cached_sm_counts;
|
||||||
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
|
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
|
||||||
cached_sm_counts[hw_info.device_id] =
|
cached_sm_counts[hw_info.device_id] =
|
||||||
@@ -350,32 +382,35 @@ void run_fp4_blockwise_scaled_group_mm_sm100(
|
|||||||
scheduler};
|
scheduler};
|
||||||
|
|
||||||
size_t workspace_size = Gemm::get_workspace_size(args);
|
size_t workspace_size = Gemm::get_workspace_size(args);
|
||||||
auto const workspace_options =
|
auto workspace =
|
||||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
std::nullopt, a.device());
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
|
||||||
|
|
||||||
auto can_implement_status = gemm_op.can_implement(args);
|
auto can_implement_status = gemm_op.can_implement(args);
|
||||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
STD_TORCH_CHECK(
|
||||||
"Failed to implement GEMM: status=", (int)can_implement_status);
|
can_implement_status == cutlass::Status::kSuccess,
|
||||||
|
"Failed to implement GEMM: status=", (int)can_implement_status);
|
||||||
|
|
||||||
// Run the GEMM
|
// Run the GEMM
|
||||||
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
||||||
TORCH_CHECK(status == cutlass::Status::kSuccess,
|
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
|
||||||
"Failed to initialize GEMM: status=", (int)status,
|
"Failed to initialize GEMM: status=", (int)status,
|
||||||
" workspace_size=", workspace_size, " num_experts=", num_experts,
|
" workspace_size=", workspace_size,
|
||||||
" M=", M, " N=", N, " K=", K);
|
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
|
||||||
|
|
||||||
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||||
}
|
}
|
||||||
|
|
||||||
void run_fp4_blockwise_scaled_group_mm_sm120(
|
void run_fp4_blockwise_scaled_group_mm_sm120(
|
||||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
torch::stable::Tensor& output, const torch::stable::Tensor& a,
|
||||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
|
||||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
const torch::stable::Tensor& b_blockscales,
|
||||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
|
const torch::stable::Tensor& alphas,
|
||||||
int N, int K) {
|
const torch::stable::Tensor& problem_sizes,
|
||||||
|
const torch::stable::Tensor& expert_offsets,
|
||||||
|
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
|
||||||
using ProblemShape =
|
using ProblemShape =
|
||||||
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
|
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
|
||||||
using ElementType = cutlass::float_e2m1_t;
|
using ElementType = cutlass::float_e2m1_t;
|
||||||
@@ -446,20 +481,40 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
|
|||||||
|
|
||||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||||
auto options_int =
|
|
||||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
|
||||||
|
|
||||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::Tensor a_ptrs =
|
||||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
std::nullopt, a.device());
|
||||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::Tensor b_ptrs =
|
||||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
|
std::nullopt, a.device());
|
||||||
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
|
torch::stable::Tensor out_ptrs =
|
||||||
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
torch::Tensor a_strides1 = torch::empty(num_experts, options_int);
|
std::nullopt, a.device());
|
||||||
torch::Tensor b_strides1 = torch::empty(num_experts, options_int);
|
torch::stable::Tensor a_scales_ptrs =
|
||||||
torch::Tensor c_strides1 = torch::empty(num_experts, options_int);
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
torch::stable::Tensor b_scales_ptrs =
|
||||||
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
torch::stable::Tensor alpha_ptrs =
|
||||||
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
torch::stable::Tensor layout_sfa = torch::stable::empty(
|
||||||
|
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
|
||||||
|
a.device());
|
||||||
|
torch::stable::Tensor layout_sfb = torch::stable::empty(
|
||||||
|
{num_experts, 5}, torch::headeronly::ScalarType::Long, std::nullopt,
|
||||||
|
a.device());
|
||||||
|
torch::stable::Tensor a_strides1 =
|
||||||
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
torch::stable::Tensor b_strides1 =
|
||||||
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
torch::stable::Tensor c_strides1 =
|
||||||
|
torch::stable::empty(num_experts, torch::headeronly::ScalarType::Long,
|
||||||
|
std::nullopt, a.device());
|
||||||
|
|
||||||
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
||||||
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
|
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
|
||||||
@@ -480,7 +535,7 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
|
|||||||
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
|
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
|
||||||
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
||||||
scheduler.raster_order = RasterOrderOptions::AlongM;
|
scheduler.raster_order = RasterOrderOptions::AlongM;
|
||||||
hw_info.device_id = a.get_device();
|
hw_info.device_id = a.get_device_index();
|
||||||
static std::unordered_map<int, int> cached_sm_counts;
|
static std::unordered_map<int, int> cached_sm_counts;
|
||||||
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
|
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
|
||||||
cached_sm_counts[hw_info.device_id] =
|
cached_sm_counts[hw_info.device_id] =
|
||||||
@@ -523,33 +578,36 @@ void run_fp4_blockwise_scaled_group_mm_sm120(
|
|||||||
scheduler};
|
scheduler};
|
||||||
|
|
||||||
size_t workspace_size = Gemm::get_workspace_size(args);
|
size_t workspace_size = Gemm::get_workspace_size(args);
|
||||||
auto const workspace_options =
|
auto workspace =
|
||||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
std::nullopt, a.device());
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
const cudaStream_t stream = get_current_cuda_stream(a.get_device_index());
|
||||||
|
|
||||||
auto can_implement_status = gemm_op.can_implement(args);
|
auto can_implement_status = gemm_op.can_implement(args);
|
||||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
STD_TORCH_CHECK(
|
||||||
"Failed to implement GEMM: status=", (int)can_implement_status);
|
can_implement_status == cutlass::Status::kSuccess,
|
||||||
|
"Failed to implement GEMM: status=", (int)can_implement_status);
|
||||||
|
|
||||||
// Run the GEMM
|
// Run the GEMM
|
||||||
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
||||||
TORCH_CHECK(status == cutlass::Status::kSuccess,
|
STD_TORCH_CHECK(status == cutlass::Status::kSuccess,
|
||||||
"Failed to initialize GEMM: status=", (int)status,
|
"Failed to initialize GEMM: status=", (int)status,
|
||||||
" workspace_size=", workspace_size, " num_experts=", num_experts,
|
" workspace_size=", workspace_size,
|
||||||
" M=", M, " N=", N, " K=", K);
|
" num_experts=", num_experts, " M=", M, " N=", N, " K=", K);
|
||||||
|
|
||||||
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename OutType>
|
template <typename OutType>
|
||||||
void run_fp4_blockwise_scaled_group_mm(
|
void run_fp4_blockwise_scaled_group_mm(
|
||||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
torch::stable::Tensor& output, const torch::stable::Tensor& a,
|
||||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
const torch::stable::Tensor& b, const torch::stable::Tensor& a_blockscale,
|
||||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
const torch::stable::Tensor& b_blockscales,
|
||||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
|
const torch::stable::Tensor& alphas,
|
||||||
int N, int K) {
|
const torch::stable::Tensor& problem_sizes,
|
||||||
|
const torch::stable::Tensor& expert_offsets,
|
||||||
|
const torch::stable::Tensor& sf_offsets, int M, int N, int K) {
|
||||||
int32_t version_num = get_sm_version_num();
|
int32_t version_num = get_sm_version_num();
|
||||||
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
||||||
if (version_num >= 120 && version_num < 130) {
|
if (version_num >= 120 && version_num < 130) {
|
||||||
@@ -567,7 +625,7 @@ void run_fp4_blockwise_scaled_group_mm(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
false,
|
false,
|
||||||
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
|
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
|
||||||
version_num, ". Required capability: 100 or 120");
|
version_num, ". Required capability: 100 or 120");
|
||||||
@@ -575,26 +633,31 @@ void run_fp4_blockwise_scaled_group_mm(
|
|||||||
|
|
||||||
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
|
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
|
||||||
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
|
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
|
||||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
|
||||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define CHECK_TYPE(x, st, m) \
|
#define CHECK_TYPE(x, st, m) \
|
||||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
STD_TORCH_CHECK(x.scalar_type() == st, \
|
||||||
|
": Inconsistency of torch::stable::Tensor type:", m)
|
||||||
#define CHECK_TH_CUDA(x, m) \
|
#define CHECK_TH_CUDA(x, m) \
|
||||||
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
|
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.")
|
||||||
#define CHECK_CONTIGUOUS(x, m) \
|
#define CHECK_CONTIGUOUS(x, m) \
|
||||||
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
|
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.")
|
||||||
#define CHECK_INPUT(x, st, m) \
|
#define CHECK_INPUT(x, st, m) \
|
||||||
CHECK_TH_CUDA(x, m); \
|
CHECK_TH_CUDA(x, m); \
|
||||||
CHECK_CONTIGUOUS(x, m); \
|
CHECK_CONTIGUOUS(x, m); \
|
||||||
CHECK_TYPE(x, st, m)
|
CHECK_TYPE(x, st, m)
|
||||||
|
|
||||||
void cutlass_fp4_group_mm(
|
void cutlass_fp4_group_mm(torch::stable::Tensor& output,
|
||||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
const torch::stable::Tensor& a,
|
||||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
const torch::stable::Tensor& b,
|
||||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
const torch::stable::Tensor& a_blockscale,
|
||||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
|
const torch::stable::Tensor& b_blockscales,
|
||||||
|
const torch::stable::Tensor& alphas,
|
||||||
|
const torch::stable::Tensor& problem_sizes,
|
||||||
|
const torch::stable::Tensor& expert_offsets,
|
||||||
|
const torch::stable::Tensor& sf_offsets) {
|
||||||
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
|
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
|
||||||
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
|
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
|
||||||
// Input validation
|
// Input validation
|
||||||
@@ -602,30 +665,34 @@ void cutlass_fp4_group_mm(
|
|||||||
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
|
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
|
||||||
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
|
CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale");
|
||||||
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
|
CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales");
|
||||||
CHECK_INPUT(alphas, at::ScalarType::Float, "alphas");
|
CHECK_INPUT(alphas, torch::headeronly::ScalarType::Float, "alphas");
|
||||||
|
|
||||||
TORCH_CHECK(a_blockscale.dim() == 2,
|
STD_TORCH_CHECK(
|
||||||
"expected a_blockscale to be of shape [num_experts, rounded_m,"
|
a_blockscale.dim() == 2,
|
||||||
" k // group_size], observed rank: ",
|
"expected a_blockscale to be of shape [num_experts, rounded_m,"
|
||||||
a_blockscale.dim())
|
" k // group_size], observed rank: ",
|
||||||
TORCH_CHECK(b_blockscales.dim() == 3,
|
a_blockscale.dim())
|
||||||
"expected b_blockscale to be of shape: "
|
STD_TORCH_CHECK(b_blockscales.dim() == 3,
|
||||||
" [num_experts, n, k // group_size], observed rank: ",
|
"expected b_blockscale to be of shape: "
|
||||||
b_blockscales.dim())
|
" [num_experts, n, k // group_size], observed rank: ",
|
||||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor");
|
b_blockscales.dim())
|
||||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
STD_TORCH_CHECK(problem_sizes.dim() == 2,
|
||||||
"problem_sizes must have the shape (num_experts, 3)");
|
"problem_sizes must be a 2D tensor");
|
||||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
STD_TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||||
"Number of experts in problem_sizes must match expert_offsets");
|
"problem_sizes must have the shape (num_experts, 3)");
|
||||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
STD_TORCH_CHECK(
|
||||||
"problem_sizes must be int32.");
|
problem_sizes.size(0) == expert_offsets.size(0),
|
||||||
|
"Number of experts in problem_sizes must match expert_offsets");
|
||||||
|
STD_TORCH_CHECK(
|
||||||
|
problem_sizes.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||||
|
"problem_sizes must be int32.");
|
||||||
|
|
||||||
int M = static_cast<int>(a.size(0));
|
int M = static_cast<int>(a.size(0));
|
||||||
int N = static_cast<int>(b.size(1));
|
int N = static_cast<int>(b.size(1));
|
||||||
int E = static_cast<int>(b.size(0));
|
int E = static_cast<int>(b.size(0));
|
||||||
int K = static_cast<int>(2 * b.size(2));
|
int K = static_cast<int>(2 * b.size(2));
|
||||||
|
|
||||||
if (output.scalar_type() == torch::kBFloat16) {
|
if (output.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||||
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
|
run_fp4_blockwise_scaled_group_mm<cutlass::bfloat16_t>(
|
||||||
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
||||||
expert_offsets, sf_offsets, M, N, K);
|
expert_offsets, sf_offsets, M, N, K);
|
||||||
@@ -633,7 +700,7 @@ void cutlass_fp4_group_mm(
|
|||||||
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
||||||
int32_t version_num = get_sm_version_num();
|
int32_t version_num = get_sm_version_num();
|
||||||
if (version_num >= 120 && version_num < 130) {
|
if (version_num >= 120 && version_num < 130) {
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
|
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
|
||||||
output.scalar_type());
|
output.scalar_type());
|
||||||
}
|
}
|
||||||
@@ -643,7 +710,7 @@ void cutlass_fp4_group_mm(
|
|||||||
expert_offsets, sf_offsets, M, N, K);
|
expert_offsets, sf_offsets, M, N, K);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
false,
|
false,
|
||||||
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
|
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
|
||||||
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
|
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
|
||||||
@@ -651,6 +718,6 @@ void cutlass_fp4_group_mm(
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, m) {
|
||||||
m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
|
m.impl("cutlass_fp4_group_mm", TORCH_BOX(&cutlass_fp4_group_mm));
|
||||||
}
|
}
|
||||||
@@ -14,16 +14,15 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
#include "libtorch_stable/torch_utils.h"
|
||||||
|
#include "libtorch_stable/dispatch_utils.h"
|
||||||
|
#include "cuda_vec_utils.cuh"
|
||||||
|
|
||||||
#include <cuda_runtime_api.h>
|
#include <cuda_runtime_api.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
|
|
||||||
#include <cuda_fp8.h>
|
#include <cuda_fp8.h>
|
||||||
#include "dispatch_utils.h"
|
|
||||||
|
|
||||||
#include "cuda_utils.h"
|
#include "cuda_utils.h"
|
||||||
#include "nvfp4_utils.cuh"
|
#include "nvfp4_utils.cuh"
|
||||||
@@ -327,25 +326,28 @@ void quant_impl(void* output, void* output_scale, void* input,
|
|||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
/*Quantization entry for fp4 experts quantization*/
|
/*Quantization entry for fp4 experts quantization*/
|
||||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
#define CHECK_TH_CUDA(x, m) \
|
||||||
|
STD_TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||||
#define CHECK_CONTIGUOUS(x, m) \
|
#define CHECK_CONTIGUOUS(x, m) \
|
||||||
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
STD_TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||||
#define CHECK_INPUT(x, m) \
|
#define CHECK_INPUT(x, m) \
|
||||||
CHECK_TH_CUDA(x, m); \
|
CHECK_TH_CUDA(x, m); \
|
||||||
CHECK_CONTIGUOUS(x, m);
|
CHECK_CONTIGUOUS(x, m);
|
||||||
|
|
||||||
constexpr auto HALF = at::ScalarType::Half;
|
constexpr auto HALF = torch::headeronly::ScalarType::Half;
|
||||||
constexpr auto BF16 = at::ScalarType::BFloat16;
|
constexpr auto BF16 = torch::headeronly::ScalarType::BFloat16;
|
||||||
constexpr auto FLOAT = at::ScalarType::Float;
|
constexpr auto FLOAT = torch::headeronly::ScalarType::Float;
|
||||||
constexpr auto INT = at::ScalarType::Int;
|
constexpr auto INT = torch::headeronly::ScalarType::Int;
|
||||||
constexpr auto UINT8 = at::ScalarType::Byte;
|
constexpr auto UINT8 = torch::headeronly::ScalarType::Byte;
|
||||||
|
|
||||||
// Common validation for fp4 experts quantization entry points.
|
// Common validation for fp4 experts quantization entry points.
|
||||||
static void validate_fp4_experts_quant_inputs(
|
static void validate_fp4_experts_quant_inputs(
|
||||||
torch::Tensor const& output, torch::Tensor const& output_scale,
|
torch::stable::Tensor const& output,
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
torch::stable::Tensor const& output_scale,
|
||||||
torch::Tensor const& input_offset_by_experts,
|
torch::stable::Tensor const& input,
|
||||||
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
|
torch::stable::Tensor const& input_global_scale,
|
||||||
|
torch::stable::Tensor const& input_offset_by_experts,
|
||||||
|
torch::stable::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
|
||||||
int64_t k) {
|
int64_t k) {
|
||||||
CHECK_INPUT(output, "output");
|
CHECK_INPUT(output, "output");
|
||||||
CHECK_INPUT(output_scale, "output_scale");
|
CHECK_INPUT(output_scale, "output_scale");
|
||||||
@@ -354,41 +356,42 @@ static void validate_fp4_experts_quant_inputs(
|
|||||||
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
|
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
|
||||||
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
|
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
|
||||||
|
|
||||||
TORCH_CHECK(output.dim() == 2);
|
STD_TORCH_CHECK(output.dim() == 2);
|
||||||
TORCH_CHECK(output_scale.dim() == 2);
|
STD_TORCH_CHECK(output_scale.dim() == 2);
|
||||||
TORCH_CHECK(input.dim() == 2);
|
STD_TORCH_CHECK(input.dim() == 2);
|
||||||
TORCH_CHECK(input_global_scale.dim() == 1);
|
STD_TORCH_CHECK(input_global_scale.dim() == 1);
|
||||||
TORCH_CHECK(input_offset_by_experts.dim() == 1);
|
STD_TORCH_CHECK(input_offset_by_experts.dim() == 1);
|
||||||
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
|
STD_TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);
|
||||||
|
|
||||||
TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
|
STD_TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
|
||||||
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
|
STD_TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
|
||||||
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
|
STD_TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
|
||||||
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
|
STD_TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
|
||||||
// output is uint8 (two nvfp4 values are packed into one uint8)
|
// output is uint8 (two nvfp4 values are packed into one uint8)
|
||||||
// output_scale is int32 (four fp8 values are packed into one int32)
|
// output_scale is int32 (four fp8 values are packed into one int32)
|
||||||
TORCH_CHECK(output.scalar_type() == UINT8);
|
STD_TORCH_CHECK(output.scalar_type() == UINT8);
|
||||||
TORCH_CHECK(output_scale.scalar_type() == INT);
|
STD_TORCH_CHECK(output_scale.scalar_type() == INT);
|
||||||
|
|
||||||
const int BLOCK_SIZE = 16;
|
const int BLOCK_SIZE = 16;
|
||||||
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
STD_TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
|
||||||
auto n_experts = input_global_scale.size(0);
|
auto n_experts = input_global_scale.size(0);
|
||||||
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
STD_TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
|
||||||
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
|
STD_TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
|
||||||
TORCH_CHECK(output.size(0) == m_topk);
|
STD_TORCH_CHECK(output.size(0) == m_topk);
|
||||||
TORCH_CHECK(output.size(1) == k / 2);
|
STD_TORCH_CHECK(output.size(1) == k / 2);
|
||||||
int scales_k = k / BLOCK_SIZE;
|
int scales_k = k / BLOCK_SIZE;
|
||||||
// 4 means the swizzle requirement by nvidia nvfp4.
|
// 4 means the swizzle requirement by nvidia nvfp4.
|
||||||
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
|
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
|
||||||
// 4 means 4 fp8 values are packed into one int32
|
// 4 means 4 fp8 values are packed into one int32
|
||||||
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
STD_TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||||
}
|
}
|
||||||
|
|
||||||
void scaled_fp4_experts_quant_sm1xxa(
|
void scaled_fp4_experts_quant_sm1xxa(
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
torch::stable::Tensor const& input,
|
||||||
torch::Tensor const& input_offset_by_experts,
|
torch::stable::Tensor const& input_global_scale,
|
||||||
torch::Tensor const& output_scale_offset_by_experts) {
|
torch::stable::Tensor const& input_offset_by_experts,
|
||||||
|
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||||
auto m_topk = input.size(0);
|
auto m_topk = input.size(0);
|
||||||
auto k = input.size(1);
|
auto k = input.size(1);
|
||||||
|
|
||||||
@@ -397,11 +400,11 @@ void scaled_fp4_experts_quant_sm1xxa(
|
|||||||
output_scale_offset_by_experts, m_topk, k);
|
output_scale_offset_by_experts, m_topk, k);
|
||||||
|
|
||||||
auto n_experts = input_global_scale.size(0);
|
auto n_experts = input_global_scale.size(0);
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
const cudaStream_t stream =
|
input.get_device_index());
|
||||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
|
||||||
|
|
||||||
VLLM_DISPATCH_HALF_TYPES(
|
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||||
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
|
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
|
||||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||||
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
|
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
|
||||||
@@ -413,14 +416,15 @@ void scaled_fp4_experts_quant_sm1xxa(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
torch::stable::Tensor const& input,
|
||||||
torch::Tensor const& input_offset_by_experts,
|
torch::stable::Tensor const& input_global_scale,
|
||||||
torch::Tensor const& output_scale_offset_by_experts) {
|
torch::stable::Tensor const& input_offset_by_experts,
|
||||||
|
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||||
auto m_topk = input.size(0);
|
auto m_topk = input.size(0);
|
||||||
// Input has gate || up layout, so k = input.size(1) / 2
|
// Input has gate || up layout, so k = input.size(1) / 2
|
||||||
auto k_times_2 = input.size(1);
|
auto k_times_2 = input.size(1);
|
||||||
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
|
STD_TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
|
||||||
auto k = k_times_2 / 2;
|
auto k = k_times_2 / 2;
|
||||||
|
|
||||||
validate_fp4_experts_quant_inputs(output, output_scale, input,
|
validate_fp4_experts_quant_inputs(output, output_scale, input,
|
||||||
@@ -428,11 +432,11 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
|||||||
output_scale_offset_by_experts, m_topk, k);
|
output_scale_offset_by_experts, m_topk, k);
|
||||||
|
|
||||||
auto n_experts = input_global_scale.size(0);
|
auto n_experts = input_global_scale.size(0);
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
const cudaStream_t stream =
|
input.get_device_index());
|
||||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
const cudaStream_t stream = get_current_cuda_stream(input.get_device_index());
|
||||||
|
|
||||||
VLLM_DISPATCH_HALF_TYPES(
|
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||||
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
|
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
|
||||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||||
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(
|
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(
|
||||||
172
csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu
Normal file
172
csrc/libtorch_stable/quantization/fp4/nvfp4_quant_entry.cu
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
|
||||||
|
#include "libtorch_stable/torch_utils.h"
|
||||||
|
|
||||||
|
#include "cutlass_extensions/common.hpp"
|
||||||
|
#include "nvfp4_utils.cuh"
|
||||||
|
|
||||||
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
|
||||||
|
torch::stable::Tensor const& input,
|
||||||
|
torch::stable::Tensor const& output_sf,
|
||||||
|
torch::stable::Tensor const& input_sf,
|
||||||
|
bool is_sf_swizzled_layout);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
void scaled_fp4_experts_quant_sm1xxa(
|
||||||
|
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||||
|
torch::stable::Tensor const& input,
|
||||||
|
torch::stable::Tensor const& input_global_scale,
|
||||||
|
torch::stable::Tensor const& input_offset_by_experts,
|
||||||
|
torch::stable::Tensor const& output_scale_offset_by_experts);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
void silu_and_mul_nvfp4_quant_sm1xxa(torch::stable::Tensor& output,
|
||||||
|
torch::stable::Tensor& output_sf,
|
||||||
|
torch::stable::Tensor& input,
|
||||||
|
torch::stable::Tensor& input_sf);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||||
|
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||||
|
torch::stable::Tensor const& input,
|
||||||
|
torch::stable::Tensor const& input_global_scale,
|
||||||
|
torch::stable::Tensor const& input_offset_by_experts,
|
||||||
|
torch::stable::Tensor const& output_scale_offset_by_experts);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
static bool nvfp4_quant_sm_supported() {
|
||||||
|
const int32_t sm = get_sm_version_num();
|
||||||
|
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
||||||
|
if (sm >= 100 && sm < 120) return true;
|
||||||
|
#endif
|
||||||
|
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
|
||||||
|
if (sm >= 120 && sm < 130) return true;
|
||||||
|
#endif
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void scaled_fp4_quant_out(torch::stable::Tensor const& input,
|
||||||
|
torch::stable::Tensor const& input_sf,
|
||||||
|
bool is_sf_swizzled_layout,
|
||||||
|
torch::stable::Tensor& output,
|
||||||
|
torch::stable::Tensor& output_sf) {
|
||||||
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||||
|
"No compiled nvfp4 quantization kernel for SM ",
|
||||||
|
get_sm_version_num(),
|
||||||
|
". Recompile with the appropriate CUDA arch.");
|
||||||
|
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
|
||||||
|
is_sf_swizzled_layout);
|
||||||
|
#endif
|
||||||
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||||
|
"No compiled nvfp4 quantization kernel");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::tuple<torch::stable::Tensor, torch::stable::Tensor> scaled_fp4_quant_func(
|
||||||
|
torch::stable::Tensor const& input, torch::stable::Tensor const& input_sf,
|
||||||
|
bool is_sf_swizzled_layout) {
|
||||||
|
int64_t n = input.size(-1);
|
||||||
|
int64_t m = input.numel() / n;
|
||||||
|
auto device = input.device();
|
||||||
|
|
||||||
|
// Two fp4 values packed into a uint8
|
||||||
|
auto output = torch::stable::empty(
|
||||||
|
{m, n / 2}, torch::headeronly::ScalarType::Byte, std::nullopt, device);
|
||||||
|
|
||||||
|
torch::stable::Tensor output_sf;
|
||||||
|
if (is_sf_swizzled_layout) {
|
||||||
|
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
|
||||||
|
output_sf = torch::stable::empty(
|
||||||
|
{sf_m, sf_n}, torch::headeronly::ScalarType::Int, std::nullopt, device);
|
||||||
|
} else {
|
||||||
|
output_sf = torch::stable::empty({m, n / CVT_FP4_SF_VEC_SIZE},
|
||||||
|
torch::headeronly::ScalarType::Byte,
|
||||||
|
std::nullopt, device);
|
||||||
|
}
|
||||||
|
|
||||||
|
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
|
||||||
|
output_sf);
|
||||||
|
return {output, output_sf};
|
||||||
|
}
|
||||||
|
|
||||||
|
void scaled_fp4_experts_quant(
|
||||||
|
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||||
|
torch::stable::Tensor const& input,
|
||||||
|
torch::stable::Tensor const& input_global_scale,
|
||||||
|
torch::stable::Tensor const& input_offset_by_experts,
|
||||||
|
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||||
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||||
|
"No compiled nvfp4 experts quantization kernel for SM ",
|
||||||
|
get_sm_version_num(),
|
||||||
|
". Recompile with the appropriate CUDA arch.");
|
||||||
|
return scaled_fp4_experts_quant_sm1xxa(
|
||||||
|
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||||
|
output_scale_offset_by_experts);
|
||||||
|
#endif
|
||||||
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false, "No compiled nvfp4 experts quantization kernel");
|
||||||
|
}
|
||||||
|
|
||||||
|
void silu_and_mul_nvfp4_quant(torch::stable::Tensor& output,
|
||||||
|
torch::stable::Tensor& output_sf,
|
||||||
|
torch::stable::Tensor& input,
|
||||||
|
torch::stable::Tensor& input_sf) {
|
||||||
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||||
|
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
|
||||||
|
get_sm_version_num(),
|
||||||
|
". Recompile with the appropriate CUDA arch.");
|
||||||
|
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
|
||||||
|
#endif
|
||||||
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false, "No compiled silu_and_mul nvfp4 quantization kernel");
|
||||||
|
}
|
||||||
|
|
||||||
|
void silu_and_mul_scaled_fp4_experts_quant(
|
||||||
|
torch::stable::Tensor& output, torch::stable::Tensor& output_scale,
|
||||||
|
torch::stable::Tensor const& input,
|
||||||
|
torch::stable::Tensor const& input_global_scale,
|
||||||
|
torch::stable::Tensor const& input_offset_by_experts,
|
||||||
|
torch::stable::Tensor const& output_scale_offset_by_experts) {
|
||||||
|
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
||||||
|
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
||||||
|
STD_TORCH_CHECK(nvfp4_quant_sm_supported(),
|
||||||
|
"No compiled silu_and_mul nvfp4 experts quantization kernel "
|
||||||
|
"for SM ",
|
||||||
|
get_sm_version_num(),
|
||||||
|
". Recompile with the appropriate CUDA arch.");
|
||||||
|
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
||||||
|
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
||||||
|
output_scale_offset_by_experts);
|
||||||
|
#endif
|
||||||
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
|
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
|
||||||
|
}
|
||||||
@@ -14,16 +14,16 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
|
||||||
#include <cuda_runtime_api.h>
|
#include <cuda_runtime_api.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
|
|
||||||
#include <cuda_fp8.h>
|
#include <cuda_fp8.h>
|
||||||
#include "dispatch_utils.h"
|
|
||||||
|
#include "libtorch_stable/torch_utils.h"
|
||||||
|
#include "libtorch_stable/dispatch_utils.h"
|
||||||
|
#include "cuda_vec_utils.cuh"
|
||||||
|
|
||||||
#include "cuda_utils.h"
|
#include "cuda_utils.h"
|
||||||
#include "launch_bounds_utils.h"
|
#include "launch_bounds_utils.h"
|
||||||
@@ -173,18 +173,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
|
|||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
void scaled_fp4_quant_sm1xxa(torch::stable::Tensor const& output,
|
||||||
torch::Tensor const& input,
|
torch::stable::Tensor const& input,
|
||||||
torch::Tensor const& output_sf,
|
torch::stable::Tensor const& output_sf,
|
||||||
torch::Tensor const& input_sf,
|
torch::stable::Tensor const& input_sf,
|
||||||
bool is_sf_swizzled_layout) {
|
bool is_sf_swizzled_layout) {
|
||||||
int32_t m = input.size(0);
|
int32_t m = input.size(0);
|
||||||
int32_t n = input.size(1);
|
int32_t n = input.size(1);
|
||||||
|
|
||||||
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
STD_TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
|
||||||
TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
|
STD_TORCH_CHECK(
|
||||||
input.scalar_type() == at::ScalarType::BFloat16,
|
input.scalar_type() == torch::headeronly::ScalarType::Half ||
|
||||||
"Unsupported input data type for quantize_to_fp4.");
|
input.scalar_type() == torch::headeronly::ScalarType::BFloat16,
|
||||||
|
"Unsupported input data type for quantize_to_fp4.");
|
||||||
|
|
||||||
int multiProcessorCount =
|
int multiProcessorCount =
|
||||||
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
|
get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
|
||||||
@@ -192,8 +193,9 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
|||||||
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
|
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
|
||||||
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
|
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
|
||||||
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
|
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
input.get_device_index());
|
||||||
|
auto stream = get_current_cuda_stream(input.get_device_index());
|
||||||
|
|
||||||
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
|
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
|
||||||
|
|
||||||
@@ -213,15 +215,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
|||||||
std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
||||||
dim3 grid(grid_x, grid_y);
|
dim3 grid(grid_x, grid_y);
|
||||||
|
|
||||||
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||||
// NOTE: We don't support e8m0 scales at this moment.
|
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||||
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
|
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
|
||||||
m, n, num_padded_cols, input_ptr, input_sf_ptr,
|
m, n, num_padded_cols, input_ptr, input_sf_ptr,
|
||||||
reinterpret_cast<uint32_t*>(output_ptr),
|
reinterpret_cast<uint32_t*>(output_ptr),
|
||||||
reinterpret_cast<uint32_t*>(sf_out));
|
reinterpret_cast<uint32_t*>(sf_out));
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
|
int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
|
||||||
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
|
int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
|
||||||
@@ -229,15 +231,15 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
|||||||
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
|
||||||
dim3 grid(grid_x, grid_y);
|
dim3 grid(grid_x, grid_y);
|
||||||
|
|
||||||
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
VLLM_STABLE_DISPATCH_HALF_TYPES(
|
||||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
input.scalar_type(), "nvfp4_quant_kernel", [&] {
|
||||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||||
// NOTE: We don't support e8m0 scales at this moment.
|
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||||
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
|
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
|
||||||
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
|
<<<grid, block, 0, stream>>>(
|
||||||
input_ptr, input_sf_ptr,
|
m, n, sf_n_unpadded, num_packed_cols, input_ptr, input_sf_ptr,
|
||||||
reinterpret_cast<uint32_t*>(output_ptr),
|
reinterpret_cast<uint32_t*>(output_ptr),
|
||||||
reinterpret_cast<uint32_t*>(sf_out));
|
reinterpret_cast<uint32_t*>(sf_out));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -14,32 +14,39 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
|
#include "libtorch_stable/torch_utils.h"
|
||||||
|
|
||||||
#include "cutlass_extensions/common.hpp"
|
#include "cutlass_extensions/common.hpp"
|
||||||
|
|
||||||
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
||||||
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
|
||||||
torch::Tensor const& B,
|
torch::stable::Tensor const& A,
|
||||||
torch::Tensor const& A_sf,
|
torch::stable::Tensor const& B,
|
||||||
torch::Tensor const& B_sf,
|
torch::stable::Tensor const& A_sf,
|
||||||
torch::Tensor const& alpha);
|
torch::stable::Tensor const& B_sf,
|
||||||
|
torch::stable::Tensor const& alpha);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
||||||
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
|
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
|
||||||
torch::Tensor const& B,
|
torch::stable::Tensor const& A,
|
||||||
torch::Tensor const& A_sf,
|
torch::stable::Tensor const& B,
|
||||||
torch::Tensor const& B_sf,
|
torch::stable::Tensor const& A_sf,
|
||||||
torch::Tensor const& alpha);
|
torch::stable::Tensor const& B_sf,
|
||||||
|
torch::stable::Tensor const& alpha);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
|
void cutlass_scaled_fp4_mm(torch::stable::Tensor& D,
|
||||||
const torch::Tensor& B, const torch::Tensor& A_sf,
|
const torch::stable::Tensor& A,
|
||||||
const torch::Tensor& B_sf,
|
const torch::stable::Tensor& B,
|
||||||
const torch::Tensor& alpha) {
|
const torch::stable::Tensor& A_sf,
|
||||||
// Make sure we’re on A’s device.
|
const torch::stable::Tensor& B_sf,
|
||||||
const c10::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
const torch::stable::Tensor& alpha) {
|
||||||
|
// Make sure we're on A's device.
|
||||||
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
|
A.get_device_index());
|
||||||
const int32_t sm = get_sm_version_num();
|
const int32_t sm = get_sm_version_num();
|
||||||
|
|
||||||
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
||||||
@@ -56,8 +63,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A,
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel for SM ", sm,
|
STD_TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
". Recompile with CUDA >= 12.8 and CC >= 100.");
|
false, "No compiled nvfp4 mm kernel for SM ", sm,
|
||||||
|
". Recompile with CUDA >= 12.8 and CC >= 100.");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
|
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
|
||||||
@@ -14,10 +14,9 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include "libtorch_stable/torch_utils.h"
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
|
|
||||||
#include "cutlass_extensions/common.hpp"
|
#include "cutlass_extensions/common.hpp"
|
||||||
|
|
||||||
@@ -127,8 +126,9 @@ struct Fp4GemmSm100 {
|
|||||||
|
|
||||||
template <typename Config>
|
template <typename Config>
|
||||||
typename Config::Gemm::Arguments args_from_options(
|
typename Config::Gemm::Arguments args_from_options(
|
||||||
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
torch::stable::Tensor& D, torch::stable::Tensor const& A,
|
||||||
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
|
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
|
||||||
|
torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha,
|
||||||
int64_t M, int64_t N, int64_t K) {
|
int64_t M, int64_t N, int64_t K) {
|
||||||
using ElementA = typename Config::Gemm::ElementA;
|
using ElementA = typename Config::Gemm::ElementA;
|
||||||
using ElementB = typename Config::Gemm::ElementB;
|
using ElementB = typename Config::Gemm::ElementB;
|
||||||
@@ -174,19 +174,20 @@ typename Config::Gemm::Arguments args_from_options(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename Config>
|
template <typename Config>
|
||||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
|
||||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
|
||||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
torch::stable::Tensor const& B_sf,
|
||||||
cudaStream_t stream) {
|
torch::stable::Tensor const& alpha, int64_t m, int64_t n,
|
||||||
|
int64_t k, cudaStream_t stream) {
|
||||||
typename Config::Gemm gemm;
|
typename Config::Gemm gemm;
|
||||||
|
|
||||||
auto arguments =
|
auto arguments =
|
||||||
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||||
|
|
||||||
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
|
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
|
||||||
auto const workspace_options =
|
auto workspace =
|
||||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
std::nullopt, A.device());
|
||||||
|
|
||||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||||
|
|
||||||
@@ -197,12 +198,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
|||||||
|
|
||||||
// Dispatch function to select appropriate config based on M
|
// Dispatch function to select appropriate config based on M
|
||||||
template <typename OutType>
|
template <typename OutType>
|
||||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
|
||||||
torch::Tensor const& B,
|
torch::stable::Tensor const& A,
|
||||||
torch::Tensor const& A_sf,
|
torch::stable::Tensor const& B,
|
||||||
torch::Tensor const& B_sf,
|
torch::stable::Tensor const& A_sf,
|
||||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
torch::stable::Tensor const& B_sf,
|
||||||
int64_t k, cudaStream_t stream) {
|
torch::stable::Tensor const& alpha, int64_t m,
|
||||||
|
int64_t n, int64_t k, cudaStream_t stream) {
|
||||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||||
|
|
||||||
if (mp2 <= 16) {
|
if (mp2 <= 16) {
|
||||||
@@ -222,61 +224,65 @@ void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
|||||||
|
|
||||||
#else
|
#else
|
||||||
template <typename OutType>
|
template <typename OutType>
|
||||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
void cutlass_fp4_gemm_dispatch(torch::stable::Tensor& D,
|
||||||
torch::Tensor const& B,
|
torch::stable::Tensor const& A,
|
||||||
torch::Tensor const& A_sf,
|
torch::stable::Tensor const& B,
|
||||||
torch::Tensor const& B_sf,
|
torch::stable::Tensor const& A_sf,
|
||||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
torch::stable::Tensor const& B_sf,
|
||||||
int64_t k, cudaStream_t stream) {
|
torch::stable::Tensor const& alpha, int64_t m,
|
||||||
TORCH_CHECK(false,
|
int64_t n, int64_t k, cudaStream_t stream) {
|
||||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
STD_TORCH_CHECK(false,
|
||||||
"a CUTLASS 3.8 source directory to enable support.");
|
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||||
|
"a CUTLASS 3.8 source directory to enable support.");
|
||||||
}
|
}
|
||||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||||
|
|
||||||
#define CHECK_TYPE(x, st, m) \
|
#define CHECK_TYPE(x, st, m) \
|
||||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
STD_TORCH_CHECK(x.scalar_type() == st, \
|
||||||
|
": Inconsistency of torch::stable::Tensor type:", m)
|
||||||
#define CHECK_TH_CUDA(x, m) \
|
#define CHECK_TH_CUDA(x, m) \
|
||||||
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||||
#define CHECK_CONTIGUOUS(x, m) \
|
#define CHECK_CONTIGUOUS(x, m) \
|
||||||
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||||
#define CHECK_INPUT(x, st, m) \
|
#define CHECK_INPUT(x, st, m) \
|
||||||
CHECK_TH_CUDA(x, m); \
|
CHECK_TH_CUDA(x, m); \
|
||||||
CHECK_CONTIGUOUS(x, m); \
|
CHECK_CONTIGUOUS(x, m); \
|
||||||
CHECK_TYPE(x, st, m)
|
CHECK_TYPE(x, st, m)
|
||||||
|
|
||||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
|
||||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
|
||||||
|
|
||||||
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
void cutlass_scaled_fp4_mm_sm100a(torch::stable::Tensor& D,
|
||||||
torch::Tensor const& B,
|
torch::stable::Tensor const& A,
|
||||||
torch::Tensor const& A_sf,
|
torch::stable::Tensor const& B,
|
||||||
torch::Tensor const& B_sf,
|
torch::stable::Tensor const& A_sf,
|
||||||
torch::Tensor const& alpha) {
|
torch::stable::Tensor const& B_sf,
|
||||||
|
torch::stable::Tensor const& alpha) {
|
||||||
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
|
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
|
||||||
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
|
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
|
||||||
|
|
||||||
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
|
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
|
||||||
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
|
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
|
||||||
|
|
||||||
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
|
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
|
||||||
|
|
||||||
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||||
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||||
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
|
STD_TORCH_CHECK(A.size(1) == B.size(1),
|
||||||
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
|
"a and b shapes cannot be multiplied (", A.size(0), "x",
|
||||||
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
|
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
|
||||||
|
|
||||||
auto const m = A.sizes()[0];
|
auto const m = A.size(0);
|
||||||
auto const n = B.sizes()[0];
|
auto const n = B.size(0);
|
||||||
auto const k = A.sizes()[1] * 2;
|
auto const k = A.size(1) * 2;
|
||||||
|
|
||||||
constexpr int alignment = 32;
|
constexpr int alignment = 32;
|
||||||
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
|
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
|
||||||
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
|
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
|
||||||
"), k: ", k, ".");
|
"), k: ", k, ".");
|
||||||
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
|
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
|
||||||
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
|
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
|
||||||
|
").");
|
||||||
|
|
||||||
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
|
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
|
||||||
int rounded_m = round_up(m, 128);
|
int rounded_m = round_up(m, 128);
|
||||||
@@ -285,33 +291,34 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
|||||||
// integer.
|
// integer.
|
||||||
int rounded_k = round_up(k / 16, 4);
|
int rounded_k = round_up(k / 16, 4);
|
||||||
|
|
||||||
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||||
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||||
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
|
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
|
||||||
"scale_a and scale_b shapes cannot be multiplied (",
|
"scale_a and scale_b shapes cannot be multiplied (",
|
||||||
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
|
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
|
||||||
"x", B_sf.sizes()[1], ")");
|
B_sf.size(1), ")");
|
||||||
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
|
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
|
||||||
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
||||||
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
|
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
|
||||||
A_sf.sizes()[1], ")");
|
A_sf.size(1), ")");
|
||||||
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
|
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
|
||||||
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
||||||
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
|
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
|
||||||
B_sf.sizes()[1], ")");
|
B_sf.size(1), ")");
|
||||||
|
|
||||||
auto out_dtype = D.dtype();
|
auto out_dtype = D.scalar_type();
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
A.get_device_index());
|
||||||
|
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
|
||||||
|
|
||||||
if (out_dtype == at::ScalarType::Half) {
|
if (out_dtype == torch::headeronly::ScalarType::Half) {
|
||||||
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
|
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
|
||||||
k, stream);
|
k, stream);
|
||||||
} else if (out_dtype == at::ScalarType::BFloat16) {
|
} else if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
|
||||||
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
|
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
|
||||||
m, n, k, stream);
|
m, n, k, stream);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
|
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (",
|
||||||
")");
|
out_dtype, ")");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -14,10 +14,9 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/all.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include "libtorch_stable/torch_utils.h"
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
|
|
||||||
#include "cutlass_extensions/common.hpp"
|
#include "cutlass_extensions/common.hpp"
|
||||||
|
|
||||||
@@ -34,19 +33,20 @@
|
|||||||
|
|
||||||
using namespace cute;
|
using namespace cute;
|
||||||
|
|
||||||
#define CHECK_TYPE(x, st, m) \
|
#define CHECK_TYPE(x, st, m) \
|
||||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
STD_TORCH_CHECK(x.scalar_type() == st, \
|
||||||
|
": Inconsistency of torch::stable::Tensor type:", m)
|
||||||
#define CHECK_TH_CUDA(x, m) \
|
#define CHECK_TH_CUDA(x, m) \
|
||||||
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
STD_TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||||
#define CHECK_CONTIGUOUS(x, m) \
|
#define CHECK_CONTIGUOUS(x, m) \
|
||||||
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
STD_TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||||
#define CHECK_INPUT(x, st, m) \
|
#define CHECK_INPUT(x, st, m) \
|
||||||
CHECK_TH_CUDA(x, m); \
|
CHECK_TH_CUDA(x, m); \
|
||||||
CHECK_CONTIGUOUS(x, m); \
|
CHECK_CONTIGUOUS(x, m); \
|
||||||
CHECK_TYPE(x, st, m)
|
CHECK_TYPE(x, st, m)
|
||||||
|
|
||||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
|
||||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
|
||||||
|
|
||||||
struct sm120_fp4_config_M256 {
|
struct sm120_fp4_config_M256 {
|
||||||
using ClusterShape = Shape<_1, _1, _1>;
|
using ClusterShape = Shape<_1, _1, _1>;
|
||||||
@@ -109,12 +109,13 @@ struct Fp4GemmSm120 {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Gemm>
|
template <typename Gemm>
|
||||||
typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
|
typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
|
||||||
at::Tensor const& B,
|
torch::stable::Tensor const& A,
|
||||||
at::Tensor const& A_sf,
|
torch::stable::Tensor const& B,
|
||||||
at::Tensor const& B_sf,
|
torch::stable::Tensor const& A_sf,
|
||||||
torch::Tensor const& alpha, int M,
|
torch::stable::Tensor const& B_sf,
|
||||||
int N, int K) {
|
torch::stable::Tensor const& alpha,
|
||||||
|
int M, int N, int K) {
|
||||||
using ElementA = typename Gemm::ElementA;
|
using ElementA = typename Gemm::ElementA;
|
||||||
using ElementB = typename Gemm::ElementB;
|
using ElementB = typename Gemm::ElementB;
|
||||||
using ElementD = typename Gemm::ElementD;
|
using ElementD = typename Gemm::ElementD;
|
||||||
@@ -158,18 +159,19 @@ typename Gemm::Arguments args_from_options(at::Tensor& D, at::Tensor const& A,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename Gemm>
|
template <typename Gemm>
|
||||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
|
||||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
|
||||||
torch::Tensor const& alpha, int M, int N, int K,
|
torch::stable::Tensor const& B_sf,
|
||||||
|
torch::stable::Tensor const& alpha, int M, int N, int K,
|
||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
Gemm gemm;
|
Gemm gemm;
|
||||||
|
|
||||||
auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
|
auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
|
||||||
|
|
||||||
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
size_t workspace_size = Gemm::get_workspace_size(arguments);
|
||||||
auto const workspace_options =
|
auto workspace =
|
||||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
std::nullopt, A.device());
|
||||||
|
|
||||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||||
|
|
||||||
@@ -178,12 +180,13 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
|||||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
|
||||||
torch::Tensor const& B,
|
torch::stable::Tensor const& A,
|
||||||
torch::Tensor const& A_sf,
|
torch::stable::Tensor const& B,
|
||||||
torch::Tensor const& B_sf,
|
torch::stable::Tensor const& A_sf,
|
||||||
torch::Tensor const& alpha, int m, int n,
|
torch::stable::Tensor const& B_sf,
|
||||||
int k, cudaStream_t stream) {
|
torch::stable::Tensor const& alpha, int m,
|
||||||
|
int n, int k, cudaStream_t stream) {
|
||||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||||
if (mp2 <= 256) {
|
if (mp2 <= 256) {
|
||||||
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
|
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
|
||||||
@@ -194,12 +197,13 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
|
||||||
torch::Tensor const& B,
|
torch::stable::Tensor const& A,
|
||||||
torch::Tensor const& A_sf,
|
torch::stable::Tensor const& B,
|
||||||
torch::Tensor const& B_sf,
|
torch::stable::Tensor const& A_sf,
|
||||||
torch::Tensor const& alpha, int m, int n,
|
torch::stable::Tensor const& B_sf,
|
||||||
int k, cudaStream_t stream) {
|
torch::stable::Tensor const& alpha, int m,
|
||||||
|
int n, int k, cudaStream_t stream) {
|
||||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||||
if (mp2 <= 256) {
|
if (mp2 <= 256) {
|
||||||
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
|
runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
|
||||||
@@ -210,11 +214,12 @@ void cutlass_fp4_f16_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
|
void cutlass_scaled_fp4_mm_sm120a(torch::stable::Tensor& D,
|
||||||
torch::Tensor const& B,
|
torch::stable::Tensor const& A,
|
||||||
torch::Tensor const& A_sf,
|
torch::stable::Tensor const& B,
|
||||||
torch::Tensor const& B_sf,
|
torch::stable::Tensor const& A_sf,
|
||||||
torch::Tensor const& alpha) {
|
torch::stable::Tensor const& B_sf,
|
||||||
|
torch::stable::Tensor const& alpha) {
|
||||||
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
#if defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||||
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
|
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
|
||||||
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
|
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
|
||||||
@@ -222,24 +227,25 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
|
|||||||
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
|
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
|
||||||
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
|
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
|
||||||
|
|
||||||
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
|
CHECK_INPUT(alpha, torch::headeronly::ScalarType::Float, "alpha");
|
||||||
|
|
||||||
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
STD_TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||||
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
STD_TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||||
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
|
STD_TORCH_CHECK(A.size(1) == B.size(1),
|
||||||
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
|
"a and b shapes cannot be multiplied (", A.size(0), "x",
|
||||||
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
|
A.size(1), " and ", B.size(0), "x", B.size(1), ")");
|
||||||
|
|
||||||
auto const m = A.sizes()[0];
|
auto const m = A.size(0);
|
||||||
auto const n = B.sizes()[0];
|
auto const n = B.size(0);
|
||||||
auto const k = A.sizes()[1] * 2;
|
auto const k = A.size(1) * 2;
|
||||||
|
|
||||||
constexpr int alignment = 32;
|
constexpr int alignment = 32;
|
||||||
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
|
STD_TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ",
|
||||||
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
|
alignment, ", but got a shape: (", A.size(0), "x", A.size(1),
|
||||||
"), k: ", k, ".");
|
"), k: ", k, ".");
|
||||||
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
|
STD_TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ",
|
||||||
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
|
alignment, ", but got b shape: (", B.size(0), "x", B.size(1),
|
||||||
|
").");
|
||||||
|
|
||||||
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
|
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
|
||||||
int rounded_m = round_up(m, 128);
|
int rounded_m = round_up(m, 128);
|
||||||
@@ -248,38 +254,39 @@ void cutlass_scaled_fp4_mm_sm120a(torch::Tensor& D, torch::Tensor const& A,
|
|||||||
// integer.
|
// integer.
|
||||||
int rounded_k = round_up(k / 16, 4);
|
int rounded_k = round_up(k / 16, 4);
|
||||||
|
|
||||||
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
STD_TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||||
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
STD_TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||||
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
|
STD_TORCH_CHECK(A_sf.size(1) == B_sf.size(1),
|
||||||
"scale_a and scale_b shapes cannot be multiplied (",
|
"scale_a and scale_b shapes cannot be multiplied (",
|
||||||
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
|
A_sf.size(0), "x", A_sf.size(1), " and ", B_sf.size(0), "x",
|
||||||
"x", B_sf.sizes()[1], ")");
|
B_sf.size(1), ")");
|
||||||
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
|
STD_TORCH_CHECK(A_sf.size(0) == rounded_m && A_sf.size(1) == rounded_k,
|
||||||
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
||||||
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
|
"x", rounded_k, "), but got a shape (", A_sf.size(0), "x",
|
||||||
A_sf.sizes()[1], ")");
|
A_sf.size(1), ")");
|
||||||
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
|
STD_TORCH_CHECK(B_sf.size(0) == rounded_n && B_sf.size(1) == rounded_k,
|
||||||
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
||||||
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
|
"x", rounded_k, "), but got a shape (", B_sf.size(0), "x",
|
||||||
B_sf.sizes()[1], ")");
|
B_sf.size(1), ")");
|
||||||
|
|
||||||
auto out_dtype = D.dtype();
|
auto out_dtype = D.scalar_type();
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
const torch::stable::accelerator::DeviceGuard device_guard(
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
A.get_device_index());
|
||||||
|
const cudaStream_t stream = get_current_cuda_stream(A.get_device_index());
|
||||||
|
|
||||||
if (out_dtype == at::ScalarType::BFloat16) {
|
if (out_dtype == torch::headeronly::ScalarType::BFloat16) {
|
||||||
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
|
return cutlass_fp4_bf16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
|
||||||
stream);
|
stream);
|
||||||
} else if (out_dtype == at::ScalarType::Half) {
|
} else if (out_dtype == torch::headeronly::ScalarType::Half) {
|
||||||
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
|
return cutlass_fp4_f16_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, k,
|
||||||
stream);
|
stream);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
|
STD_TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm sm120 (",
|
||||||
out_dtype, ")");
|
out_dtype, ")");
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
TORCH_CHECK(false,
|
STD_TORCH_CHECK(false,
|
||||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||||
"a CUTLASS 3.8 source directory to enable support.");
|
"a CUTLASS 3.8 source directory to enable support.");
|
||||||
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
#endif // defined(CUTLASS_ARCH_MMA_SM120_SUPPORTED)
|
||||||
}
|
}
|
||||||
@@ -20,7 +20,7 @@
|
|||||||
#include <cuda_fp8.h>
|
#include <cuda_fp8.h>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "../../cuda_vec_utils.cuh"
|
#include "cuda_vec_utils.cuh"
|
||||||
|
|
||||||
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
|
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
|
||||||
CUDA_VERSION >= 12090
|
CUDA_VERSION >= 12090
|
||||||
@@ -103,6 +103,102 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
|||||||
ops.def(
|
ops.def(
|
||||||
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||||
"bool");
|
"bool");
|
||||||
|
|
||||||
|
// CUTLASS nvfp4 block scaled GEMM
|
||||||
|
ops.def(
|
||||||
|
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
||||||
|
" Tensor block_scale_a, Tensor block_scale_b,"
|
||||||
|
" Tensor alpha) -> ()");
|
||||||
|
|
||||||
|
// cutlass nvfp4 block scaled group GEMM
|
||||||
|
ops.def(
|
||||||
|
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
|
||||||
|
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
|
||||||
|
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
|
||||||
|
|
||||||
|
// Compute NVFP4 block quantized tensor.
|
||||||
|
ops.def(
|
||||||
|
"scaled_fp4_quant(Tensor input,"
|
||||||
|
" Tensor input_scale, bool "
|
||||||
|
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
|
||||||
|
|
||||||
|
// Out variant
|
||||||
|
// TODO: Add out_variant tag once PyTorch supports it (added in 2.11)
|
||||||
|
// This registration is now migrated to stable ABI
|
||||||
|
// at::Tag::out_variant is not available in the stable ABI (enum_tag.h is not
|
||||||
|
// yet in torch/headeronly), the tag should be applied from Python
|
||||||
|
// via torch.library.Library.define(..., tags=(torch.Tag.out_variant,))
|
||||||
|
// with the .impl remaining in C++.
|
||||||
|
// See pytorch/pytorch#176117.
|
||||||
|
ops.def(
|
||||||
|
"scaled_fp4_quant.out(Tensor input,"
|
||||||
|
" Tensor input_scale, bool "
|
||||||
|
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
|
||||||
|
"-> ()");
|
||||||
|
|
||||||
|
// Compute NVFP4 experts quantization.
|
||||||
|
ops.def(
|
||||||
|
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
|
||||||
|
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||||
|
"Tensor output_scale_offset_by_experts) -> ()");
|
||||||
|
|
||||||
|
// Fused SiLU+Mul+NVFP4 experts quantization.
|
||||||
|
ops.def(
|
||||||
|
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
|
||||||
|
"output_scale,"
|
||||||
|
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||||
|
"Tensor output_scale_offset_by_experts) -> ()");
|
||||||
|
|
||||||
|
// Fused SiLU+Mul+NVFP4 quantization.
|
||||||
|
ops.def(
|
||||||
|
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
|
||||||
|
"Tensor input, Tensor input_global_scale) -> ()");
|
||||||
|
|
||||||
|
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
|
||||||
|
// of the given capability
|
||||||
|
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
|
||||||
|
|
||||||
|
// CUTLASS w4a8 GEMM
|
||||||
|
ops.def(
|
||||||
|
"cutlass_w4a8_mm("
|
||||||
|
" Tensor A,"
|
||||||
|
" Tensor B,"
|
||||||
|
" Tensor group_scales,"
|
||||||
|
" int group_size,"
|
||||||
|
" Tensor channel_scales,"
|
||||||
|
" Tensor token_scales,"
|
||||||
|
" ScalarType? out_type,"
|
||||||
|
" str? maybe_schedule"
|
||||||
|
") -> Tensor");
|
||||||
|
|
||||||
|
// pack scales
|
||||||
|
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
|
||||||
|
|
||||||
|
// encode and reorder weight matrix
|
||||||
|
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
|
||||||
|
|
||||||
|
// CUTLASS w4a8 grouped GEMM
|
||||||
|
ops.def(
|
||||||
|
"cutlass_w4a8_moe_mm("
|
||||||
|
" Tensor! out_tensors,"
|
||||||
|
" Tensor a_tensors,"
|
||||||
|
" Tensor b_tensors,"
|
||||||
|
" Tensor a_scales,"
|
||||||
|
" Tensor b_scales,"
|
||||||
|
" Tensor b_group_scales,"
|
||||||
|
" int b_group_size,"
|
||||||
|
" Tensor expert_offsets,"
|
||||||
|
" Tensor problem_sizes,"
|
||||||
|
" Tensor a_strides,"
|
||||||
|
" Tensor b_strides,"
|
||||||
|
" Tensor c_strides,"
|
||||||
|
" Tensor group_scale_strides,"
|
||||||
|
" str? maybe_schedule"
|
||||||
|
") -> ()");
|
||||||
|
|
||||||
|
ops.def(
|
||||||
|
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
|
||||||
|
"Tensor)");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,6 +224,18 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
|||||||
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
|
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
|
||||||
ops.impl("get_cutlass_batched_moe_mm_data",
|
ops.impl("get_cutlass_batched_moe_mm_data",
|
||||||
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
|
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
|
||||||
|
|
||||||
|
// FP4/NVFP4 ops
|
||||||
|
ops.impl("cutlass_scaled_fp4_mm", TORCH_BOX(&cutlass_scaled_fp4_mm));
|
||||||
|
ops.impl("scaled_fp4_quant", TORCH_BOX(&scaled_fp4_quant_func));
|
||||||
|
ops.impl("scaled_fp4_quant.out", TORCH_BOX(&scaled_fp4_quant_out));
|
||||||
|
ops.impl("scaled_fp4_experts_quant", TORCH_BOX(&scaled_fp4_experts_quant));
|
||||||
|
ops.impl("silu_and_mul_scaled_fp4_experts_quant",
|
||||||
|
TORCH_BOX(&silu_and_mul_scaled_fp4_experts_quant));
|
||||||
|
ops.impl("silu_and_mul_nvfp4_quant", TORCH_BOX(&silu_and_mul_nvfp4_quant));
|
||||||
|
|
||||||
|
// W4A8 ops: impl registrations are in the source files
|
||||||
|
// (w4a8_mm_entry.cu and w4a8_grouped_mm_entry.cu)
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,6 +251,8 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
|
|||||||
TORCH_BOX(&cutlass_group_gemm_supported));
|
TORCH_BOX(&cutlass_group_gemm_supported));
|
||||||
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||||
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
|
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
|
||||||
|
ops.impl("cutlass_scaled_mm_supports_fp4",
|
||||||
|
TORCH_BOX(&cutlass_scaled_mm_supports_fp4));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||||
#include <torch/csrc/stable/accelerator.h>
|
#include <torch/csrc/stable/accelerator.h>
|
||||||
|
#include <torch/csrc/stable/ops.h>
|
||||||
#include <torch/csrc/stable/tensor.h>
|
#include <torch/csrc/stable/tensor.h>
|
||||||
#include <torch/headeronly/util/shim_utils.h>
|
#include <torch/headeronly/util/shim_utils.h>
|
||||||
|
|
||||||
|
|||||||
44
csrc/ops.h
44
csrc/ops.h
@@ -152,12 +152,6 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
|||||||
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
|
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
|
||||||
torch::Tensor& scale);
|
torch::Tensor& scale);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
void silu_and_mul_nvfp4_quant(torch::Tensor& out,
|
|
||||||
torch::Tensor& output_block_scale,
|
|
||||||
torch::Tensor& input,
|
|
||||||
torch::Tensor& input_global_scale);
|
|
||||||
#endif
|
|
||||||
void persistent_masked_m_silu_mul_quant(
|
void persistent_masked_m_silu_mul_quant(
|
||||||
const at::Tensor& input, // (E, T, 2*H)
|
const at::Tensor& input, // (E, T, 2*H)
|
||||||
const at::Tensor& counts, // (E)
|
const at::Tensor& counts, // (E)
|
||||||
@@ -225,44 +219,6 @@ torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
|
|||||||
|
|
||||||
int64_t ggml_moe_get_block_size(int64_t type);
|
int64_t ggml_moe_get_block_size(int64_t type);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
|
|
||||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
|
||||||
|
|
||||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
|
||||||
torch::Tensor const& B, torch::Tensor const& A_sf,
|
|
||||||
torch::Tensor const& B_sf,
|
|
||||||
torch::Tensor const& alpha);
|
|
||||||
|
|
||||||
void cutlass_fp4_group_mm(
|
|
||||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
|
||||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
|
||||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
|
||||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
|
|
||||||
|
|
||||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
|
||||||
torch::Tensor const& input, torch::Tensor const& input_scale,
|
|
||||||
bool is_sf_swizzled_layout);
|
|
||||||
|
|
||||||
void scaled_fp4_quant_out(torch::Tensor const& input,
|
|
||||||
torch::Tensor const& input_scale,
|
|
||||||
bool is_sf_swizzled_layout, torch::Tensor& output,
|
|
||||||
torch::Tensor& output_scale);
|
|
||||||
|
|
||||||
void scaled_fp4_experts_quant(
|
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
|
||||||
torch::Tensor const& input_offset_by_experts,
|
|
||||||
torch::Tensor const& output_scale_offset_by_experts);
|
|
||||||
|
|
||||||
void silu_and_mul_scaled_fp4_experts_quant(
|
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
|
||||||
torch::Tensor const& input_offset_by_experts,
|
|
||||||
torch::Tensor const& output_scale_offset_by_experts);
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||||
torch::Tensor const& scale,
|
torch::Tensor const& scale,
|
||||||
std::optional<torch::Tensor> const& azp);
|
std::optional<torch::Tensor> const& azp);
|
||||||
|
|||||||
@@ -1,163 +0,0 @@
|
|||||||
/*
|
|
||||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <torch/all.h>
|
|
||||||
|
|
||||||
#include "cutlass_extensions/common.hpp"
|
|
||||||
#include "nvfp4_utils.cuh"
|
|
||||||
|
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
|
||||||
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
|
|
||||||
torch::Tensor const& input,
|
|
||||||
torch::Tensor const& output_sf,
|
|
||||||
torch::Tensor const& input_sf,
|
|
||||||
bool is_sf_swizzled_layout);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
|
||||||
void scaled_fp4_experts_quant_sm1xxa(
|
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
|
||||||
torch::Tensor const& input_offset_by_experts,
|
|
||||||
torch::Tensor const& output_scale_offset_by_experts);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
|
||||||
void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
|
|
||||||
torch::Tensor& output_sf,
|
|
||||||
torch::Tensor& input,
|
|
||||||
torch::Tensor& input_sf);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
|
||||||
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
|
||||||
torch::Tensor const& input_offset_by_experts,
|
|
||||||
torch::Tensor const& output_scale_offset_by_experts);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
static bool nvfp4_quant_sm_supported() {
|
|
||||||
const int32_t sm = get_sm_version_num();
|
|
||||||
#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100
|
|
||||||
if (sm >= 100 && sm < 120) return true;
|
|
||||||
#endif
|
|
||||||
#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120
|
|
||||||
if (sm >= 120 && sm < 130) return true;
|
|
||||||
#endif
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void scaled_fp4_quant_out(torch::Tensor const& input,
|
|
||||||
torch::Tensor const& input_sf,
|
|
||||||
bool is_sf_swizzled_layout, torch::Tensor& output,
|
|
||||||
torch::Tensor& output_sf) {
|
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
|
||||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
|
||||||
"No compiled nvfp4 quantization kernel for SM ",
|
|
||||||
get_sm_version_num(),
|
|
||||||
". Recompile with the appropriate CUDA arch.");
|
|
||||||
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
|
|
||||||
is_sf_swizzled_layout);
|
|
||||||
#endif
|
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
|
||||||
torch::Tensor const& input, torch::Tensor const& input_sf,
|
|
||||||
bool is_sf_swizzled_layout) {
|
|
||||||
int64_t n = input.size(-1);
|
|
||||||
int64_t m = input.numel() / n;
|
|
||||||
auto device = input.device();
|
|
||||||
|
|
||||||
// Two fp4 values packed into a uint8
|
|
||||||
auto output = torch::empty(
|
|
||||||
{m, n / 2}, torch::TensorOptions().device(device).dtype(torch::kUInt8));
|
|
||||||
|
|
||||||
torch::Tensor output_sf;
|
|
||||||
if (is_sf_swizzled_layout) {
|
|
||||||
auto [sf_m, sf_n] = vllm::computeSwizzledSFShape(m, n);
|
|
||||||
output_sf = torch::empty(
|
|
||||||
{sf_m, sf_n},
|
|
||||||
torch::TensorOptions().device(device).dtype(torch::kInt32));
|
|
||||||
} else {
|
|
||||||
output_sf = torch::empty(
|
|
||||||
{m, n / CVT_FP4_SF_VEC_SIZE},
|
|
||||||
torch::TensorOptions().device(device).dtype(torch::kUInt8));
|
|
||||||
}
|
|
||||||
|
|
||||||
scaled_fp4_quant_out(input, input_sf, is_sf_swizzled_layout, output,
|
|
||||||
output_sf);
|
|
||||||
return {output, output_sf};
|
|
||||||
}
|
|
||||||
|
|
||||||
void scaled_fp4_experts_quant(
|
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
|
||||||
torch::Tensor const& input_offset_by_experts,
|
|
||||||
torch::Tensor const& output_scale_offset_by_experts) {
|
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
|
||||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
|
||||||
"No compiled nvfp4 experts quantization kernel for SM ",
|
|
||||||
get_sm_version_num(),
|
|
||||||
". Recompile with the appropriate CUDA arch.");
|
|
||||||
return scaled_fp4_experts_quant_sm1xxa(
|
|
||||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
|
||||||
output_scale_offset_by_experts);
|
|
||||||
#endif
|
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
|
||||||
"No compiled nvfp4 experts quantization kernel");
|
|
||||||
}
|
|
||||||
|
|
||||||
void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
|
|
||||||
torch::Tensor& input, torch::Tensor& input_sf) {
|
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
|
||||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
|
||||||
"No compiled silu_and_mul nvfp4 quantization kernel for SM ",
|
|
||||||
get_sm_version_num(),
|
|
||||||
". Recompile with the appropriate CUDA arch.");
|
|
||||||
return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf);
|
|
||||||
#endif
|
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
|
||||||
false, "No compiled silu_and_mul nvfp4 quantization kernel");
|
|
||||||
}
|
|
||||||
|
|
||||||
void silu_and_mul_scaled_fp4_experts_quant(
|
|
||||||
torch::Tensor& output, torch::Tensor& output_scale,
|
|
||||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
|
||||||
torch::Tensor const& input_offset_by_experts,
|
|
||||||
torch::Tensor const& output_scale_offset_by_experts) {
|
|
||||||
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
|
|
||||||
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
|
|
||||||
TORCH_CHECK(nvfp4_quant_sm_supported(),
|
|
||||||
"No compiled silu_and_mul nvfp4 experts quantization kernel "
|
|
||||||
"for SM ",
|
|
||||||
get_sm_version_num(),
|
|
||||||
". Recompile with the appropriate CUDA arch.");
|
|
||||||
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
|
|
||||||
output, output_scale, input, input_global_scale, input_offset_by_experts,
|
|
||||||
output_scale_offset_by_experts);
|
|
||||||
#endif
|
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
|
||||||
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
|
|
||||||
}
|
|
||||||
@@ -109,13 +109,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||||
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
ops.def(
|
|
||||||
"silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
|
|
||||||
"Tensor input, Tensor input_global_scale) -> ()");
|
|
||||||
ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant);
|
|
||||||
#endif
|
|
||||||
|
|
||||||
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
||||||
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
|
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
|
||||||
|
|
||||||
@@ -332,47 +325,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"Tensor? qzeros_or_none, bool inplace) -> Tensor");
|
"Tensor? qzeros_or_none, bool inplace) -> Tensor");
|
||||||
// conditionally compiled so impl registrations are in source file
|
// conditionally compiled so impl registrations are in source file
|
||||||
|
|
||||||
// CUTLASS w4a8 GEMM
|
|
||||||
ops.def(
|
|
||||||
"cutlass_w4a8_mm("
|
|
||||||
" Tensor A,"
|
|
||||||
" Tensor B,"
|
|
||||||
" Tensor group_scales,"
|
|
||||||
" int group_size,"
|
|
||||||
" Tensor channel_scales,"
|
|
||||||
" Tensor token_scales,"
|
|
||||||
" ScalarType? out_type,"
|
|
||||||
" str? maybe_schedule"
|
|
||||||
") -> Tensor");
|
|
||||||
// pack scales
|
|
||||||
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
|
|
||||||
// encode and reorder weight matrix
|
|
||||||
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
|
|
||||||
// conditionally compiled so impl registration is in source file
|
|
||||||
|
|
||||||
// CUTLASS w4a8 grouped GEMM
|
|
||||||
ops.def(
|
|
||||||
"cutlass_w4a8_moe_mm("
|
|
||||||
" Tensor! out_tensors,"
|
|
||||||
" Tensor a_tensors,"
|
|
||||||
" Tensor b_tensors,"
|
|
||||||
" Tensor a_scales,"
|
|
||||||
" Tensor b_scales,"
|
|
||||||
" Tensor b_group_scales,"
|
|
||||||
" int b_group_size,"
|
|
||||||
" Tensor expert_offsets,"
|
|
||||||
" Tensor problem_sizes,"
|
|
||||||
" Tensor a_strides,"
|
|
||||||
" Tensor b_strides,"
|
|
||||||
" Tensor c_strides,"
|
|
||||||
" Tensor group_scale_strides,"
|
|
||||||
" str? maybe_schedule"
|
|
||||||
") -> ()");
|
|
||||||
ops.def(
|
|
||||||
"cutlass_encode_and_reorder_int4b_grouped(Tensor b_tensors) -> (Tensor, "
|
|
||||||
"Tensor)");
|
|
||||||
// conditionally compiled so impl registration is in source file
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Dequantization for GGML.
|
// Dequantization for GGML.
|
||||||
@@ -409,20 +361,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
|
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
// CUTLASS nvfp4 block scaled GEMM
|
|
||||||
ops.def(
|
|
||||||
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
|
||||||
" Tensor block_scale_a, Tensor block_scale_b,"
|
|
||||||
" Tensor alpha) -> ()");
|
|
||||||
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
|
||||||
|
|
||||||
// cutlass nvfp4 block scaled group GEMM
|
|
||||||
ops.def(
|
|
||||||
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
|
|
||||||
" Tensor a_blockscale, Tensor b_blockscales, Tensor alphas,"
|
|
||||||
" Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()");
|
|
||||||
// conditionally compiled so impl registration is in source file
|
|
||||||
|
|
||||||
// Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
|
// Expert-specialization mxfp8 blockscaled grouped quantization (SM100+).
|
||||||
ops.def(
|
ops.def(
|
||||||
"mxfp8_experts_quant("
|
"mxfp8_experts_quant("
|
||||||
@@ -455,44 +393,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
"-> int");
|
"-> int");
|
||||||
// conditionally compiled so impl in source file
|
// conditionally compiled so impl in source file
|
||||||
|
|
||||||
// Compute NVFP4 block quantized tensor.
|
|
||||||
ops.def(
|
|
||||||
"scaled_fp4_quant(Tensor input,"
|
|
||||||
" Tensor input_scale, bool "
|
|
||||||
"is_sf_swizzled_layout) -> (Tensor, Tensor)");
|
|
||||||
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant_func);
|
|
||||||
|
|
||||||
// Out variant
|
|
||||||
// TODO: Add {at::Tag::out_variant} tag and update all call sites
|
|
||||||
// to use the functional variant once vLLM upgrades PyTorch.
|
|
||||||
// See pytorch/pytorch#176117.
|
|
||||||
ops.def(
|
|
||||||
"scaled_fp4_quant.out(Tensor input,"
|
|
||||||
" Tensor input_scale, bool "
|
|
||||||
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
|
|
||||||
"-> ()");
|
|
||||||
ops.impl("scaled_fp4_quant.out", torch::kCUDA, &scaled_fp4_quant_out);
|
|
||||||
|
|
||||||
// Compute NVFP4 experts quantization.
|
|
||||||
ops.def(
|
|
||||||
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
|
|
||||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
|
||||||
"Tensor output_scale_offset_by_experts) -> ()");
|
|
||||||
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
|
|
||||||
|
|
||||||
// Fused SiLU+Mul+NVFP4 experts quantization.
|
|
||||||
ops.def(
|
|
||||||
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
|
|
||||||
"output_scale,"
|
|
||||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
|
||||||
"Tensor output_scale_offset_by_experts) -> ()");
|
|
||||||
ops.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA,
|
|
||||||
&silu_and_mul_scaled_fp4_experts_quant);
|
|
||||||
|
|
||||||
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
|
|
||||||
// of the given capability
|
|
||||||
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
|
|
||||||
ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Quantized GEMM for GPTQ.
|
// Quantized GEMM for GPTQ.
|
||||||
|
|||||||
Reference in New Issue
Block a user