[3/n] Migrate cutlass/scaled_mm_entry.cu torch stable ABI (#37221)

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
mikaylagawarecki
2026-03-30 14:20:13 -04:00
committed by GitHub
parent b5e608258e
commit ab1a6a43fa
54 changed files with 1842 additions and 1610 deletions

View File

@@ -6,14 +6,16 @@
#include <cstdio>
#include <cstdlib>
#include <torch/headeronly/util/shim_utils.h>
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \
cutlassGetStatusString(error)); \
}
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {

View File

@@ -3,6 +3,14 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
// (stable ABI) targets. When compiled under the stable ABI target,
// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we
// use torch::stable::Tensor instead.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#endif
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
@@ -15,6 +23,12 @@
namespace vllm::c3x {
#ifdef TORCH_TARGET_VERSION
using TensorType = torch::stable::Tensor;
#else
using TensorType = torch::Tensor;
#endif
using namespace cute;
template <typename T>
@@ -84,7 +98,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
static auto args_from_tensor(TensorType const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
@@ -100,7 +114,7 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
static auto args_from_tensor(std::optional<TensorType> const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
@@ -158,8 +172,8 @@ struct ScaledEpilogue
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
@@ -203,9 +217,9 @@ struct ScaledEpilogueBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& bias) {
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& azp_adj,
std::optional<TensorType> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken
EVTComputeScaleB, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
std::optional<torch::Tensor> const& bias) {
static ArgumentType prepare_args(TensorType const& a_scales,
TensorType const& b_scales,
TensorType const& azp_adj,
TensorType const& azp,
std::optional<TensorType> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/csrc/stable/tensor.h>
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
@@ -52,7 +54,7 @@ struct ScaledEpilogueBase {
// from a tensor. It can handle both row and column, as well as row/column or
// scalar cases.
template <typename Descriptor, typename T>
static auto args_from_tensor(torch::Tensor const& tensor) {
static auto args_from_tensor(torch::stable::Tensor const& tensor) {
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
@@ -68,7 +70,8 @@ struct ScaledEpilogueBase {
// This overload handles the case where there might not be a tensor, in which
// case a nullptr is passed and a constant (0) is used.
template <typename Descriptor, typename T>
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
static auto args_from_tensor(
std::optional<torch::stable::Tensor> const& tensor) {
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
using Arguments = typename Descriptor::Arguments;
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
@@ -117,8 +120,8 @@ struct ScaledEpilogue
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
@@ -160,9 +163,9 @@ struct ScaledEpilogueBias
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -220,10 +223,11 @@ struct ScaledEpilogueBiasAzp
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& bias) {
static ArgumentType prepare_args(
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -298,11 +302,11 @@ struct ScaledEpilogueBiasAzpToken
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
torch::Tensor const& azp,
std::optional<torch::Tensor> const& bias) {
static ArgumentType prepare_args(
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp,
std::optional<torch::stable::Tensor> const& bias) {
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

View File

@@ -27,4 +27,61 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
torch::stable::Tensor& output_s,
int64_t group_size, double eps, double int8_min,
double int8_max);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch);
void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void get_cutlass_moe_mm_data(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);
void get_cutlass_batched_moe_mm_data(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
#endif

View File

@@ -2,9 +2,10 @@
// clang-format will break include orders
// clang-format off
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <ATen/cuda/CUDAContext.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass/cutlass.h"
@@ -25,14 +26,14 @@
namespace vllm::c3x {
static inline cute::Shape<int, int, int, int> get_problem_shape(
torch::Tensor const& a, torch::Tensor const& b) {
torch::stable::Tensor const& a, torch::stable::Tensor const& b) {
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
return {m, n, k, 1};
}
template <typename GemmKernel>
void cutlass_gemm_caller(
torch::Device device, cute::Shape<int, int, int, int> prob_shape,
torch::stable::Device device, cute::Shape<int, int, int, int> prob_shape,
typename GemmKernel::MainloopArguments mainloop_args,
typename GemmKernel::EpilogueArguments epilogue_args,
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
@@ -50,19 +51,20 @@ void cutlass_gemm_caller(
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(device);
auto workspace = torch::empty(workspace_size, workspace_options);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, device);
auto stream = at::cuda::getCurrentCUDAStream(device.index());
auto stream = get_current_cuda_stream(device.index());
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
void cutlass_gemm_caller(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementC = typename Gemm::ElementC;

View File

@@ -4,13 +4,12 @@
namespace vllm {
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias) {
void cutlass_scaled_mm_azp_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
if (azp) {
return cutlass_scaled_mm_sm90_int8_epilogue<
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,

View File

@@ -0,0 +1,22 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -130,10 +132,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
static constexpr bool swap_ab = Gemm::swap_ab;
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
@@ -200,11 +202,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
}
template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

View File

@@ -0,0 +1,22 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 {
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -196,11 +198,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
}
template <typename OutType>
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
int M = a.size(0);
if (M <= 256) {
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;

View File

@@ -0,0 +1,23 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
@@ -101,10 +103,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -120,7 +122,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
StrideA a_stride;
StrideB b_stride;
@@ -161,11 +163,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
}
template <typename OutType>
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
// TODO: better heuristics
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, 128, 128, Shape<_128, _128, _128>,

View File

@@ -1,52 +1,57 @@
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias,
void dispatch_scaled_mm(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias,
Fp8Func fp8_func, Int8Func int8_func,
BlockwiseFunc blockwise_func) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
int M = a.size(0), N = b.size(1), K = a.size(1);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) {
fp8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias);
} else {
int32_t version_num = get_sm_version_num();
TORCH_CHECK(
STD_TORCH_CHECK(
false, "Int8 not supported on SM", version_num,
". Use FP8 quantization instead, or run on older arch (SM < 100).");
}
}
} else {
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
TORCH_CHECK(
STD_TORCH_CHECK(
a.size(0) == a_scales.size(0) &&
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
"a_scale_group_shape must be [1, 128].");
TORCH_CHECK(
STD_TORCH_CHECK(
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
"b_scale_group_shape must be [128, 128].");
}
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
blockwise_func(c, a, b, a_scales, b_scales);
}
}

View File

@@ -0,0 +1,52 @@
#pragma once
#include <torch/csrc/stable/tensor.h>
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
void cutlass_scaled_mm_blockwise_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales);
} // namespace vllm

View File

@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm100_fp8_dispatch.cuh"
namespace vllm {
void cutlass_scaled_mm_sm100_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias);
} else {
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
@@ -192,8 +194,9 @@ struct sm100_fp8_config_M16_swap_ab {
};
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
void cutlass_gemm_caller_sm100_fp8(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB;
@@ -237,15 +240,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
template <typename InType, typename OutType, bool EnableBias,
typename... EpilogueArgs>
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
EpilogueArgs&&... args) {
inline void cutlass_gemm_sm100_fp8_dispatch(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using Cutlass3xGemmDefault =
typename sm100_fp8_config_default<InType, OutType,
@@ -292,22 +295,24 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
}
template <bool EnableBias, typename... EpilogueArgs>
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, EnableBias>(
out, a, b, a_scales, b_scales,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, EnableBias>(
out, a, b, a_scales, b_scales,

View File

@@ -0,0 +1,25 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm120_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
@@ -138,13 +140,15 @@ struct sm120_fp8_config_M16 {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm120_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
int M = a.size(0);
@@ -177,19 +181,21 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);

View File

@@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
@@ -235,8 +237,9 @@ struct sm90_fp8_config_M16_N8192 {
};
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
void cutlass_gemm_caller_sm90_fp8(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB;
@@ -280,15 +283,15 @@ void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
template <typename InType, typename OutType, bool EnableBias,
typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
EpilogueArgs&&... args) {
inline void cutlass_gemm_sm90_fp8_dispatch(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using Cutlass3xGemmDefault =
typename sm90_fp8_config_default<InType, OutType,
@@ -347,22 +350,24 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
}
template <bool EnableBias, typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, EnableBias>(
out, a, b, a_scales, b_scales,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, EnableBias>(
out, a, b, a_scales, b_scales,

View File

@@ -0,0 +1,25 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm90_int8(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
@@ -87,13 +89,13 @@ struct sm90_int8_config_M32_NSmall {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm90_int8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
using Cutlass3xGemmDefault =
typename sm90_int8_config_default<InType, OutType,
@@ -142,19 +144,19 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
void cutlass_scaled_mm_sm90_int8_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.dtype() == torch::kBFloat16) {
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}

View File

@@ -1,10 +1,10 @@
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "libtorch_stable/torch_utils.h"
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts(
}
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
<<<1, num_experts, 0, stream>>>( \
static_cast<int64_t*>(expert_offsets.data_ptr()), \
@@ -51,32 +51,39 @@ __global__ void get_group_gemm_starts(
namespace {
void run_get_group_gemm_starts(
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
torch::Tensor& out_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
STD_TORCH_CHECK(a_tensors.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b_tensors.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
// expect int64_t to avoid overflow during offset calculations
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
torch::headeronly::ScalarType::Long);
int num_experts = static_cast<int>(expert_offsets.size(0));
bool per_act_token = a_scales.numel() != 1;
bool per_out_ch = b_scales.numel() != num_experts;
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
if (false) {
}
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
cutlass::bfloat16_t)
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
}
} // namespace
} // namespace

View File

@@ -6,6 +6,7 @@
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include <torch/csrc/stable/ops.h>
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"
@@ -84,13 +85,17 @@ struct cutlass_3x_group_gemm {
};
template <typename Gemm>
void cutlass_group_gemm_caller(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
void cutlass_group_gemm_caller(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB;
@@ -98,16 +103,20 @@ void cutlass_group_gemm_caller(
int num_experts = static_cast<int>(expert_offsets.size(0));
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
auto options_int =
torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
auto device = a_tensors.device();
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
torch::stable::Tensor a_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor out_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
@@ -156,7 +165,7 @@ void cutlass_group_gemm_caller(
static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides.data_ptr())};
int device_id = a_tensors.device().index();
int device_id = a_tensors.get_device_index();
static const cutlass::KernelHardwareInfo hw_info{
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
device_id)};
@@ -170,9 +179,9 @@ void cutlass_group_gemm_caller(
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, device);
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);

View File

@@ -1,7 +1,8 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "libtorch_stable/torch_utils.h"
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"
@@ -62,21 +63,27 @@ struct sm100_fp8_config_N8192 {
};
template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
void run_cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
@@ -107,14 +114,18 @@ void run_cutlass_moe_mm_sm100(
}
} // namespace
void dispatch_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
void dispatch_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
@@ -127,13 +138,17 @@ void dispatch_moe_mm_sm100(
}
}
void cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);

View File

@@ -1,7 +1,8 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "libtorch_stable/torch_utils.h"
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"
@@ -103,21 +104,27 @@ struct sm90_fp8_config_N8192 {
};
template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm90(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
void run_cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
STD_TORCH_CHECK(
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
@@ -163,14 +170,18 @@ void run_cutlass_moe_mm_sm90(
}
}
void dispatch_moe_mm_sm90(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
void dispatch_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
@@ -185,13 +196,17 @@ void dispatch_moe_mm_sm90(
} // namespace
void cutlass_moe_mm_sm90(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);

View File

@@ -1,9 +1,11 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "libtorch_stable/torch_utils.h"
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <torch/headeronly/core/ScalarType.h>
#include "dispatch_utils.h"
#include "libtorch_stable/dispatch_utils.h"
#include <iostream>
@@ -110,19 +112,22 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
}
namespace {
inline void launch_compute_problem_sizes(
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer,
int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream,
const bool swap_ab, const bool is_gated) {
inline void launch_compute_problem_sizes(const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& atomic_buffer,
int64_t num_experts, int64_t n,
int64_t k, cudaStream_t stream,
const bool swap_ab,
const bool is_gated) {
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>();
auto const* topk_ptr = topk_ids.const_data_ptr<int32_t>();
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
auto* atomic_ptr = atomic_buffer.mutable_data_ptr<int32_t>();
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
@@ -171,46 +176,53 @@ __global__ void compute_problem_sizes_from_expert_offsets(
}
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab) {
TORCH_CHECK(expert_first_token_offset.is_cuda(),
"expert_first_token_offset must be a CUDA tensor");
TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64,
"expert_first_token_offset must be int64");
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab) {
STD_TORCH_CHECK(expert_first_token_offset.is_cuda(),
"expert_first_token_offset must be a CUDA tensor");
STD_TORCH_CHECK(expert_first_token_offset.scalar_type() ==
torch::headeronly::ScalarType::Long,
"expert_first_token_offset must be int64");
TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
"problem_sizes must be CUDA tensors");
TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 &&
problem_sizes2.dtype() == torch::kInt32,
"problem_sizes must be int32");
TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
"problem_sizes must be contiguous");
TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
"problem_sizes must be 2D tensors");
TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
"problem_sizes second dim must be 3");
TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(),
"problem_sizes1 and problem_sizes2 must have same shape");
STD_TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
"problem_sizes must be CUDA tensors");
STD_TORCH_CHECK(
problem_sizes1.scalar_type() == torch::headeronly::ScalarType::Int &&
problem_sizes2.scalar_type() == torch::headeronly::ScalarType::Int,
"problem_sizes must be int32");
STD_TORCH_CHECK(
problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
"problem_sizes must be contiguous");
STD_TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
"problem_sizes must be 2D tensors");
STD_TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
"problem_sizes second dim must be 3");
STD_TORCH_CHECK(problem_sizes1.size(0) == problem_sizes2.size(0) &&
problem_sizes1.size(1) == problem_sizes2.size(1),
"problem_sizes1 and problem_sizes2 must have same shape");
int64_t const num_experts64 = problem_sizes1.size(0);
TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1,
"expert_first_token_offset must have num_experts + 1 elements");
TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32");
STD_TORCH_CHECK(
expert_first_token_offset.numel() == num_experts64 + 1,
"expert_first_token_offset must have num_experts + 1 elements");
STD_TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
STD_TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX,
"n and k must fit in int32");
int const num_experts = static_cast<int>(num_experts64);
auto stream = at::cuda::getCurrentCUDAStream(
expert_first_token_offset.device().index());
auto stream =
get_current_cuda_stream(expert_first_token_offset.get_device_index());
int const threads = (num_experts < 256) ? num_experts : 256;
int const blocks = (num_experts + threads - 1) / threads;
auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>();
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
auto const* offsets_ptr = expert_first_token_offset.const_data_ptr<int64_t>();
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
compute_problem_sizes_from_expert_offsets<SwapAB>
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
num_experts, static_cast<int>(n),
@@ -219,16 +231,19 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
}
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets,
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated) {
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
auto options_int32 =
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
auto device = topk_ids.device();
auto stream = get_current_cuda_stream(device.index());
torch::stable::Tensor atomic_buffer = torch::stable::new_zeros(
topk_ids, {num_experts}, torch::headeronly::ScalarType::Int);
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
@@ -290,11 +305,13 @@ __global__ void compute_batched_moe_data(
}
void get_cutlass_batched_moe_mm_data_caller(
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
auto stream = get_current_cuda_stream(expert_offsets.get_device_index());
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
@@ -311,4 +328,4 @@ void get_cutlass_batched_moe_mm_data_caller(
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}
}
}

View File

@@ -0,0 +1,220 @@
#include <stddef.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/core/ScalarType.h>
#include "cutlass/cutlass.h"
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using namespace vllm;
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm75(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm75(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm80(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm80(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
if (a.scalar_type() == torch::headeronly::ScalarType::Char) {
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_mm_sm89(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (bias) {
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
"currently bias dtype must match output dtype ",
out.scalar_type());
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm89(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
if (azp) {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}

View File

@@ -1,8 +1,9 @@
#pragma once
#include <stddef.h>
#include <torch/all.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/csrc/stable/ops.h>
#include <ATen/cuda/CUDAContext.h>
#include "libtorch_stable/torch_utils.h"
// clang-format will break include orders
// clang-format off
@@ -95,8 +96,9 @@ struct cutlass_2x_gemm {
};
template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_caller(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
@@ -149,11 +151,12 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
// Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto device = a.device();
auto workspace =
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
std::nullopt, device);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
auto stream = get_current_cuda_stream(device.index());
CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
@@ -161,9 +164,9 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
}
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void fallback_cutlass_gemm_caller(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
@@ -180,8 +183,8 @@ inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
return cutlass_gemm_caller<Gemm>(out, a, b,
std::forward<EpilogueArgs>(args)...);
} else {
TORCH_CHECK(fallback_gemm_shared_mem_size <=
max_shared_mem_per_block_opt_in);
STD_TORCH_CHECK(fallback_gemm_shared_mem_size <=
max_shared_mem_per_block_opt_in);
return cutlass_gemm_caller<FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh"
/**
@@ -70,13 +72,13 @@ struct sm75_config_M32 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm75_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
using Cutlass2xGemmDefault =
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh"
/**
@@ -72,13 +74,13 @@ struct sm80_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm80_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
using Cutlass2xGemmDefault =
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"
@@ -34,10 +36,12 @@ struct sm89_fp8_config_default {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -84,10 +88,12 @@ struct sm89_fp8_config_M256 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -125,10 +131,12 @@ struct sm89_fp8_config_M128 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -173,10 +181,12 @@ struct sm89_fp8_config_M64 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -227,10 +237,12 @@ struct sm89_fp8_config_M32 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -280,10 +292,12 @@ struct sm89_fp8_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
@@ -326,13 +340,15 @@ struct sm89_fp8_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm89_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
STD_TORCH_CHECK(a.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
STD_TORCH_CHECK(b.scalar_type() ==
torch::headeronly::ScalarType::Float8_e4m3fn);
uint32_t const m = a.size(0);
uint32_t const mp2 =

View File

@@ -1,5 +1,7 @@
#pragma once
#include <torch/headeronly/util/shim_utils.h>
#include "scaled_mm_c2x.cuh"
/**
@@ -32,10 +34,11 @@ struct sm89_int8_config_default {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -88,10 +91,11 @@ struct sm89_int8_config_M256 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -143,10 +147,11 @@ struct sm89_int8_config_M128 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -193,10 +198,11 @@ struct sm89_int8_config_M64 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -234,10 +240,11 @@ struct sm89_int8_config_M32 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -276,10 +283,11 @@ struct sm89_int8_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static void dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
@@ -311,13 +319,13 @@ struct sm89_int8_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
inline void cutlass_gemm_sm89_int8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
uint32_t const m = a.size(0);
uint32_t const mp2 =

View File

@@ -8,11 +8,12 @@
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm100_fp8,
nullptr, // int8 not supported on SM100

View File

@@ -8,11 +8,12 @@
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm120_fp8,
nullptr, // int8 not supported on SM120

View File

@@ -0,0 +1,38 @@
#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
*/
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm90_fp8,
vllm::cutlass_scaled_mm_sm90_int8,
vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
}
void cutlass_scaled_mm_azp_sm90(
torch::stable::Tensor& out, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
STD_TORCH_CHECK(a_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
STD_TORCH_CHECK(b_scales.scalar_type() ==
torch::headeronly::ScalarType::Float);
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
azp, bias);
}
#endif

View File

@@ -0,0 +1,451 @@
#include <cudaTypedefs.h>
#include <torch/csrc/stable/tensor.h>
#include "libtorch_stable/torch_utils.h"
#include "cutlass_extensions/common.hpp"
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm80(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_sm89(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias);
#endif
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab);
void get_cutlass_batched_moe_mm_data_caller(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
#endif
void cutlass_scaled_mm_azp_sm75(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm80(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm89(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(
torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias);
#endif
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
// CUDA 12.4 on SM89 systems (Lovelace)
#if defined CUDA_VERSION
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 89) {
return CUDA_VERSION >= 12040;
}
#endif
return false;
}
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
// and at least SM90 (Hopper)
#if defined CUDA_VERSION
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
} else if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
}
#endif
return false;
}
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
// or CUDA 12.8 and SM100 (Blackwell)
#if defined CUDA_VERSION
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
}
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12030;
}
#endif
return false;
}
void cutlass_scaled_mm(torch::stable::Tensor& c, torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
std::optional<torch::stable::Tensor> const& bias) {
// Checks for conformality
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
// Check for strides and alignment
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
if (bias) {
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
bias->dim() == 1);
}
const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
if (version_num >= 120) {
cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
if (version_num >= 100 && version_num < 120) {
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90 && version_num < 100) {
// Hopper
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 75) {
// Turing
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
torch::stable::Tensor const& a_tensors,
torch::stable::Tensor const& b_tensors,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& expert_offsets,
torch::stable::Tensor const& problem_sizes,
torch::stable::Tensor const& a_strides,
torch::stable::Tensor const& b_strides,
torch::stable::Tensor const& c_strides, bool per_act_token,
bool per_out_ch) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100 && version_num < 110) {
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
if (version_num >= 90 && version_num < 100) {
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
". Required capability: 90 or 100");
}
void get_cutlass_moe_mm_data(
const torch::stable::Tensor& topk_ids,
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
torch::stable::Tensor& input_permutation,
torch::stable::Tensor& output_permutation, const int64_t num_experts,
const int64_t n, const int64_t k,
const std::optional<torch::stable::Tensor>& blockscale_offsets,
const bool is_gated) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
blockscale_offsets, is_gated);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::stable::Tensor& expert_first_token_offset,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
const bool swap_ab) {
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
"no cutlass_scaled_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_batched_moe_mm_data(
torch::stable::Tensor& expert_offsets,
torch::stable::Tensor& problem_sizes1,
torch::stable::Tensor& problem_sizes2,
const torch::stable::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_batched_moe_mm_data: no "
"cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void cutlass_scaled_mm_azp(torch::stable::Tensor& c,
torch::stable::Tensor const& a,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales,
torch::stable::Tensor const& azp_adj,
std::optional<torch::stable::Tensor> const& azp,
std::optional<torch::stable::Tensor> const& bias) {
// Checks for conformality
STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
STD_TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
STD_TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment
STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
STD_TORCH_CHECK(b.stride(0) == 1); // Column-major
STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
// bias, azp, azp_adj are all 1d
// bias and azp_adj have n elements, azp has m elements
if (bias) {
STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
}
if (azp) {
STD_TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
}
STD_TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
// azp & bias types
STD_TORCH_CHECK(azp_adj.scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(!azp ||
azp->scalar_type() == torch::headeronly::ScalarType::Int);
STD_TORCH_CHECK(!bias || bias->scalar_type() == c.scalar_type(),
"currently bias dtype must match output dtype ",
c.scalar_type());
const torch::stable::accelerator::DeviceGuard device_guard(
a.get_device_index());
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
// Turing
STD_TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
#endif
STD_TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ",
version_num);
}

View File

@@ -31,6 +31,78 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()");
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
// Check if cutlass grouped gemm is supported for CUDA devices of the given
// capability
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
// CUTLASS w8a8 grouped GEMM
ops.def(
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
" bool per_out_ch) -> ()");
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM. It takes topk_ids as an input, and computes expert_offsets
// (token start indices of each expert). In addition to this, it computes
// problem sizes for each expert's multiplication used by the two mms called
// from fused MoE operation, and arrays with permutations required to shuffle
// and de-shuffle the input/output of the fused operation.
ops.def(
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k, Tensor? blockscale_offsets, "
" bool is_gated) -> ()");
// compute per-expert problem sizes from expert_first_token_offset
// produced by vLLM's moe_permute kernel
ops.def(
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
" Tensor expert_first_token_offset, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int n, int k, bool swap_ab) -> ()");
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM in batched expert format. It takes expert_num_tokens
// as an input, and computes expert_offsets (token start indices of each
// expert). In addition to this, it computes problem sizes for each expert's
// multiplication used by the two mms called from fused MoE operation.
ops.def(
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
" int n, int k) -> ()");
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
"bool");
#endif
}
@@ -46,6 +118,31 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
TORCH_BOX(&per_token_group_quant_8bit_packed));
ops.impl("per_token_group_quant_int8",
TORCH_BOX(&per_token_group_quant_int8));
// CUTLASS scaled_mm ops
ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
ops.impl("cutlass_moe_mm", TORCH_BOX(&cutlass_moe_mm));
ops.impl("get_cutlass_moe_mm_data", TORCH_BOX(&get_cutlass_moe_mm_data));
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets",
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
ops.impl("get_cutlass_batched_moe_mm_data",
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
#endif
}
// These capability-check functions take only primitive args (no tensors), so
// there is no device to dispatch on. CompositeExplicitAutograd makes them
// available for all backends. This is the stable ABI equivalent of calling
// ops.impl("op_name", &func) without a dispatch key in the non-stable API.
STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
#ifndef USE_ROCM
ops.impl("cutlass_scaled_mm_supports_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_fp8));
ops.impl("cutlass_group_gemm_supported",
TORCH_BOX(&cutlass_group_gemm_supported));
ops.impl("cutlass_scaled_mm_supports_block_fp8",
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
#endif
}

View File

@@ -1,10 +1,16 @@
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/accelerator.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/shim_utils.h>
#include <cuda_runtime.h>
// Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED.
#define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__)
// Utility to get the current CUDA stream for a given device using stable APIs.
// Returns a cudaStream_t for use in kernel launches.
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {

View File

@@ -228,63 +228,18 @@ int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B, torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_moe_mm(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
void cutlass_fp4_group_mm(
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab);
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_scale,
bool is_sf_swizzled_layout);

View File

@@ -1,23 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,23 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,24 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,56 +0,0 @@
#pragma once
#include <torch/all.h>
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
} // namespace vllm

View File

@@ -1,23 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm100_fp8_dispatch.cuh"
namespace vllm {
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias);
} else {
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
b_scales);
}
}
} // namespace vllm

View File

@@ -1,24 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,23 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_fp8_dispatch.cuh"
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
b_scales);
}
}
} // namespace vllm

View File

@@ -1,24 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm90_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@@ -1,199 +0,0 @@
#include <stddef.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
using namespace vllm;
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (azp) {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
out, a, b, a_scales, b_scales, azp_adj, bias);
}
}

View File

@@ -1,36 +0,0 @@
#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
*/
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm90_fp8,
vllm::cutlass_scaled_mm_sm90_int8,
vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
}
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
azp, bias);
}
#endif

View File

@@ -1,420 +0,0 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
void cutlass_moe_mm_sm90(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
void cutlass_moe_mm_sm100(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
#endif
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
void get_cutlass_moe_mm_data_caller(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated);
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab);
void get_cutlass_batched_moe_mm_data_caller(
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
const int64_t k);
#endif
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
#endif
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
// CUDA 12.4 on SM89 systems (Lovelace)
#if defined CUDA_VERSION
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 89) {
return CUDA_VERSION >= 12040;
}
#endif
return false;
}
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
// CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
// and at least SM90 (Hopper)
#if defined CUDA_VERSION
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
} else if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
}
#endif
return false;
}
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
// or CUDA 12.8 and SM100 (Blackwell)
#if defined CUDA_VERSION
if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
}
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12030;
}
#endif
return false;
}
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
if (bias) {
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
bias->dim() == 1);
}
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
if (version_num >= 120) {
cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
if (version_num >= 100 && version_num < 120) {
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90 && version_num < 100) {
// Hopper
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
return;
}
if (version_num >= 75) {
// Turing
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}
void cutlass_moe_mm(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100 && version_num < 110) {
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
if (version_num >= 90 && version_num < 100) {
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
". Required capability: 90 or 100");
}
void get_cutlass_moe_mm_data(
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
const int64_t num_experts, const int64_t n, const int64_t k,
const std::optional<torch::Tensor>& blockscale_offsets,
const bool is_gated) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, input_permutation,
output_permutation, num_experts, n, k,
blockscale_offsets, is_gated);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
"CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
const torch::Tensor& expert_first_token_offset,
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
const int64_t n, const int64_t k, const bool swap_ab) {
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
"no cutlass_scaled_mm kernel for CUDA device capability: ",
version_num, ". Required capability: 90, 100, or 120");
}
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
const torch::Tensor& expert_num_tokens,
const int64_t num_local_experts,
const int64_t padded_m, const int64_t n,
const int64_t k) {
// This function currently gets compiled only if we have a valid cutlass moe
// mm to run it for.
int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
problem_sizes2, expert_num_tokens,
num_local_experts, padded_m, n, k);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false,
"No compiled get_cutlass_batched_moe_mm_data: no "
"cutlass_scaled_mm kernel "
"for CUDA device capability: ",
version_num,
". Required capability: 90, 100, or 120");
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias) {
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
// bias, azp, azp_adj are all 1d
// bias and azp_adj have n elements, azp has m elements
if (bias) {
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
}
if (azp) {
TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
}
TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
// azp & bias types
TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
"currently bias dtype must match output dtype ", c.dtype());
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ",
version_num);
}

View File

@@ -439,90 +439,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" -> ()");
// conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
// Check if cutlass grouped gemm is supported for CUDA devices of the given
// capability
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
ops.impl("cutlass_group_gemm_supported", &cutlass_group_gemm_supported);
// CUTLASS w8a8 grouped GEMM
ops.def(
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
" bool per_out_ch) -> ()");
ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM. It takes topk_ids as an input, and computes expert_offsets
// (token start indices of each expert). In addition to this, it computes
// problem sizes for each expert's multiplication used by the two mms called
// from fused MoE operation, and arrays with permutations required to shuffle
// and de-shuffle the input/output of the fused operation.
ops.def(
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k, Tensor? blockscale_offsets, "
" bool is_gated) -> ()");
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
// compute per-expert problem sizes from expert_first_token_offset
// produced by vLLM's moe_permute kernel
ops.def(
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
" Tensor expert_first_token_offset, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" int n, int k, bool swap_ab) -> ()");
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets", torch::kCUDA,
&get_cutlass_moe_mm_problem_sizes_from_expert_offsets);
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM in batched expert format. It takes expert_num_tokens
// as an input, and computes expert_offsets (token start indices of each
// expert). In addition to this, it computes problem sizes for each expert's
// multiplication used by the two mms called from fused MoE operation.
ops.def(
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
" Tensor! problem_sizes1, "
" Tensor! problem_sizes2, "
" Tensor expert_num_tokens, "
" int num_local_experts, int padded_m, "
" int n, int k) -> ()");
ops.impl("get_cutlass_batched_moe_mm_data", torch::kCUDA,
&get_cutlass_batched_moe_mm_data);
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops.def(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
"bool");
ops.impl("cutlass_scaled_mm_supports_block_fp8",
&cutlass_scaled_mm_supports_block_fp8);
// SM100 CUTLASS MLA decode
ops.def(
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"