|
|
|
|
@@ -22,6 +22,7 @@
|
|
|
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
|
|
|
#include <c10/cuda/CUDAGuard.h>
|
|
|
|
|
#include <c10/cuda/CUDAStream.h>
|
|
|
|
|
#include "cutlass_extensions/common.hpp"
|
|
|
|
|
|
|
|
|
|
#include "cute/tensor.hpp"
|
|
|
|
|
#include "cutlass/tensor_ref.h"
|
|
|
|
|
@@ -173,7 +174,7 @@ void run_get_group_gemm_starts(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename OutType>
|
|
|
|
|
void run_fp4_blockwise_scaled_group_mm(
|
|
|
|
|
void run_fp4_blockwise_scaled_group_mm_sm100(
|
|
|
|
|
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
|
|
|
|
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
|
|
|
|
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
|
|
|
|
@@ -343,17 +344,225 @@ void run_fp4_blockwise_scaled_group_mm(
|
|
|
|
|
|
|
|
|
|
auto can_implement_status = gemm_op.can_implement(args);
|
|
|
|
|
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
|
|
|
|
"Failed to implement GEMM");
|
|
|
|
|
"Failed to implement GEMM: status=", (int)can_implement_status);
|
|
|
|
|
|
|
|
|
|
// Run the GEMM
|
|
|
|
|
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
|
|
|
|
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
|
|
|
|
TORCH_CHECK(status == cutlass::Status::kSuccess,
|
|
|
|
|
"Failed to initialize GEMM: status=", (int)status,
|
|
|
|
|
" workspace_size=", workspace_size, " num_experts=", num_experts,
|
|
|
|
|
" M=", M, " N=", N, " K=", K);
|
|
|
|
|
|
|
|
|
|
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
|
|
|
|
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void run_fp4_blockwise_scaled_group_mm_sm120(
|
|
|
|
|
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
|
|
|
|
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
|
|
|
|
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
|
|
|
|
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
|
|
|
|
|
int N, int K) {
|
|
|
|
|
using ProblemShape =
|
|
|
|
|
cutlass::gemm::GroupProblemShape<Shape<int32_t, int32_t, int32_t>>;
|
|
|
|
|
using ElementType = cutlass::float_e2m1_t;
|
|
|
|
|
using ElementSFType = cutlass::float_ue4m3_t;
|
|
|
|
|
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
|
|
|
|
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
|
|
|
|
|
|
|
|
|
// NOTE: For SM120 it seems templating the output type is not supported and
|
|
|
|
|
// we need to hardcode the output type to bfloat16
|
|
|
|
|
using ElementC = cutlass::bfloat16_t;
|
|
|
|
|
using ElementD = ElementC;
|
|
|
|
|
using ElementAccumulator = float;
|
|
|
|
|
// Layout definitions
|
|
|
|
|
using LayoutA = cutlass::layout::RowMajor;
|
|
|
|
|
using LayoutB = cutlass::layout::ColumnMajor;
|
|
|
|
|
using LayoutC = cutlass::layout::RowMajor;
|
|
|
|
|
using LayoutD = LayoutC;
|
|
|
|
|
|
|
|
|
|
// Alignment constraints
|
|
|
|
|
static constexpr int AlignmentA = 32;
|
|
|
|
|
static constexpr int AlignmentB = 32;
|
|
|
|
|
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
|
|
|
|
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
|
|
|
|
|
|
|
|
|
// Architecture definitions
|
|
|
|
|
using ArchTag = cutlass::arch::Sm120;
|
|
|
|
|
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
|
|
|
|
|
|
|
|
|
using ClusterShape = Shape<_1, _1, _1>;
|
|
|
|
|
using MmaTileShape = Shape<_128, _128, _128>;
|
|
|
|
|
|
|
|
|
|
using FusionOperation = cutlass::epilogue::fusion::LinearCombination<
|
|
|
|
|
ElementD, ElementAccumulator, ElementC, ElementAccumulator>;
|
|
|
|
|
|
|
|
|
|
using CollectiveEpilogue =
|
|
|
|
|
typename cutlass::epilogue::collective::CollectiveBuilder<
|
|
|
|
|
ArchTag, OperatorClass, MmaTileShape, ClusterShape,
|
|
|
|
|
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
|
|
|
|
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
|
|
|
|
|
LayoutD*, AlignmentD,
|
|
|
|
|
cutlass::epilogue::collective::EpilogueScheduleAuto,
|
|
|
|
|
FusionOperation>::CollectiveOp;
|
|
|
|
|
|
|
|
|
|
using CollectiveMainloop =
|
|
|
|
|
typename cutlass::gemm::collective::CollectiveBuilder<
|
|
|
|
|
ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB,
|
|
|
|
|
LayoutB*, AlignmentB, ElementAccumulator, MmaTileShape, ClusterShape,
|
|
|
|
|
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
|
|
|
|
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
|
|
|
|
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
|
|
|
|
|
|
|
|
|
|
using GemmKernel =
|
|
|
|
|
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
|
|
|
|
CollectiveEpilogue>;
|
|
|
|
|
|
|
|
|
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
|
|
|
|
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
|
|
|
|
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
|
|
|
|
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
|
|
|
|
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
|
|
|
|
|
|
|
|
|
using LayoutSFA =
|
|
|
|
|
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA;
|
|
|
|
|
using LayoutSFB =
|
|
|
|
|
typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB;
|
|
|
|
|
using ScaleConfig =
|
|
|
|
|
typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
|
|
|
|
|
|
|
|
|
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
|
|
|
|
int num_experts = static_cast<int>(expert_offsets.size(0));
|
|
|
|
|
auto options_int =
|
|
|
|
|
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
|
|
|
|
|
|
|
|
|
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
|
|
|
|
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
|
|
|
|
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
|
|
|
|
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
|
|
|
|
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
|
|
|
|
torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int);
|
|
|
|
|
torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int);
|
|
|
|
|
torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int);
|
|
|
|
|
torch::Tensor c_strides1 =
|
|
|
|
|
torch::full({num_experts}, output.stride(0), options_int);
|
|
|
|
|
torch::Tensor a_strides1 =
|
|
|
|
|
torch::full({num_experts}, a.stride(0) * 2, options_int);
|
|
|
|
|
torch::Tensor b_strides1 =
|
|
|
|
|
torch::full({num_experts}, b.stride(1) * 2, options_int);
|
|
|
|
|
|
|
|
|
|
run_get_group_gemm_starts<LayoutSFA, LayoutSFB, ScaleConfig>(
|
|
|
|
|
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs,
|
|
|
|
|
layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas,
|
|
|
|
|
expert_offsets, sf_offsets, problem_sizes, M, N, K);
|
|
|
|
|
|
|
|
|
|
// Create an instance of the GEMM
|
|
|
|
|
Gemm gemm_op;
|
|
|
|
|
|
|
|
|
|
// Initialize problem_sizes_as_shapes correctly
|
|
|
|
|
UnderlyingProblemShape* problem_sizes_as_shapes =
|
|
|
|
|
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
|
|
|
|
|
|
|
|
|
// Set the Scheduler info
|
|
|
|
|
cutlass::KernelHardwareInfo hw_info;
|
|
|
|
|
using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions;
|
|
|
|
|
typename Gemm::GemmKernel::TileSchedulerArguments scheduler;
|
|
|
|
|
scheduler.raster_order = RasterOrderOptions::AlongM;
|
|
|
|
|
hw_info.device_id = a.get_device();
|
|
|
|
|
static std::unordered_map<int, int> cached_sm_counts;
|
|
|
|
|
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
|
|
|
|
|
cached_sm_counts[hw_info.device_id] =
|
|
|
|
|
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
|
|
|
|
hw_info.device_id);
|
|
|
|
|
}
|
|
|
|
|
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
|
|
|
|
|
|
|
|
|
|
// Mainloop Arguments
|
|
|
|
|
typename GemmKernel::MainloopArguments mainloop_args{
|
|
|
|
|
static_cast<const ElementType**>(a_ptrs.data_ptr()),
|
|
|
|
|
static_cast<StrideA*>(a_strides1.data_ptr()),
|
|
|
|
|
static_cast<const ElementType**>(b_ptrs.data_ptr()),
|
|
|
|
|
static_cast<StrideB*>(b_strides1.data_ptr()),
|
|
|
|
|
static_cast<const ElementSFType**>(a_scales_ptrs.data_ptr()),
|
|
|
|
|
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()),
|
|
|
|
|
static_cast<const ElementSFType**>(b_scales_ptrs.data_ptr()),
|
|
|
|
|
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr())};
|
|
|
|
|
|
|
|
|
|
// Epilogue Arguments
|
|
|
|
|
typename GemmKernel::EpilogueArguments epilogue_args{
|
|
|
|
|
{}, // epilogue.thread
|
|
|
|
|
nullptr,
|
|
|
|
|
static_cast<StrideC*>(c_strides1.data_ptr()),
|
|
|
|
|
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
|
|
|
|
static_cast<StrideC*>(c_strides1.data_ptr())};
|
|
|
|
|
auto& fusion_args = epilogue_args.thread;
|
|
|
|
|
fusion_args.alpha_ptr_array =
|
|
|
|
|
reinterpret_cast<float**>(alpha_ptrs.data_ptr());
|
|
|
|
|
fusion_args.dAlpha = {_0{}, _0{}, 1};
|
|
|
|
|
fusion_args.beta = 0.0f;
|
|
|
|
|
|
|
|
|
|
// Gemm Arguments
|
|
|
|
|
typename GemmKernel::Arguments args{
|
|
|
|
|
cutlass::gemm::GemmUniversalMode::kGrouped,
|
|
|
|
|
{num_experts, problem_sizes_as_shapes, nullptr},
|
|
|
|
|
mainloop_args,
|
|
|
|
|
epilogue_args,
|
|
|
|
|
hw_info,
|
|
|
|
|
scheduler};
|
|
|
|
|
|
|
|
|
|
size_t workspace_size = Gemm::get_workspace_size(args);
|
|
|
|
|
auto const workspace_options =
|
|
|
|
|
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
|
|
|
|
auto workspace = torch::empty(workspace_size, workspace_options);
|
|
|
|
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
|
|
|
|
|
|
|
|
|
auto can_implement_status = gemm_op.can_implement(args);
|
|
|
|
|
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
|
|
|
|
"Failed to implement GEMM: status=", (int)can_implement_status);
|
|
|
|
|
|
|
|
|
|
// Run the GEMM
|
|
|
|
|
auto status = gemm_op.initialize(args, workspace.data_ptr());
|
|
|
|
|
TORCH_CHECK(status == cutlass::Status::kSuccess,
|
|
|
|
|
"Failed to initialize GEMM: status=", (int)status,
|
|
|
|
|
" workspace_size=", workspace_size, " num_experts=", num_experts,
|
|
|
|
|
" M=", M, " N=", N, " K=", K);
|
|
|
|
|
|
|
|
|
|
status = gemm_op.run(args, workspace.data_ptr(), stream);
|
|
|
|
|
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename OutType>
|
|
|
|
|
void run_fp4_blockwise_scaled_group_mm(
|
|
|
|
|
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
|
|
|
|
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
|
|
|
|
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
|
|
|
|
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M,
|
|
|
|
|
int N, int K) {
|
|
|
|
|
int32_t version_num = get_sm_version_num();
|
|
|
|
|
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
|
|
|
|
if (version_num >= 120 && version_num < 130) {
|
|
|
|
|
run_fp4_blockwise_scaled_group_mm_sm120(
|
|
|
|
|
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
|
|
|
|
expert_offsets, sf_offsets, M, N, K);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
|
|
|
|
if (version_num >= 100 && version_num < 120) {
|
|
|
|
|
run_fp4_blockwise_scaled_group_mm_sm100<OutType>(
|
|
|
|
|
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
|
|
|
|
expert_offsets, sf_offsets, M, N, K);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
|
|
|
|
false,
|
|
|
|
|
"No compiled cutlass_fp4_group_mm kernel for CUDA device capability: ",
|
|
|
|
|
version_num, ". Required capability: 100 or 120");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
|
|
|
|
|
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
|
|
|
|
|
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
|
|
|
|
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
|
|
|
|
#endif
|
|
|
|
|
@@ -374,7 +583,8 @@ void cutlass_fp4_group_mm(
|
|
|
|
|
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
|
|
|
|
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
|
|
|
|
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) {
|
|
|
|
|
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100
|
|
|
|
|
#if (defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100) || \
|
|
|
|
|
(defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120)
|
|
|
|
|
// Input validation
|
|
|
|
|
CHECK_INPUT(a, FLOAT4_E2M1X2, "a");
|
|
|
|
|
CHECK_INPUT(b, FLOAT4_E2M1X2, "b");
|
|
|
|
|
@@ -408,6 +618,14 @@ void cutlass_fp4_group_mm(
|
|
|
|
|
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
|
|
|
|
expert_offsets, sf_offsets, M, N, K);
|
|
|
|
|
} else {
|
|
|
|
|
#if defined ENABLE_NVFP4_SM120 && ENABLE_NVFP4_SM120
|
|
|
|
|
int32_t version_num = get_sm_version_num();
|
|
|
|
|
if (version_num >= 120 && version_num < 130) {
|
|
|
|
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
|
|
|
|
false, "SM120 NVFP4 MOE only supports bfloat16 output, got: ",
|
|
|
|
|
output.scalar_type());
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
run_fp4_blockwise_scaled_group_mm<cutlass::half_t>(
|
|
|
|
|
output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes,
|
|
|
|
|
expert_offsets, sf_offsets, M, N, K);
|
|
|
|
|
@@ -416,8 +634,8 @@ void cutlass_fp4_group_mm(
|
|
|
|
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
|
|
|
|
false,
|
|
|
|
|
"No compiled cutlass_fp4_group_mm kernel, vLLM must "
|
|
|
|
|
"be compiled with ENABLE_NVFP4_SM100 for SM100+ and CUDA "
|
|
|
|
|
"12.8 or above.");
|
|
|
|
|
"be compiled with ENABLE_NVFP4_SM100 or ENABLE_NVFP4_SM120 for SM100/120 "
|
|
|
|
|
"and CUDA 12.8 or above.");
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|