[3/n] Migrate cutlass/scaled_mm_entry.cu to torch stable ABI (#37221)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
@@ -6,14 +6,16 @@
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
/**
 * Helper macro for checking CUTLASS errors.
 *
 * `status` is evaluated exactly once and stored in a local named `error`.
 * If it is not cutlass::Status::kSuccess, STD_TORCH_CHECK reports the failure
 * using the text returned by cutlassGetStatusString(error).
 *
 * The expansion is a braced block, so `error` does not leak into the
 * caller's scope.
 */
#define CUTLASS_CHECK(status)                           \
  {                                                     \
    cutlass::Status error = status;                     \
    STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \
                    cutlassGetStatusString(error));     \
  }
|
||||
|
||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
|
||||
@@ -3,6 +3,14 @@
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
|
||||
|
||||
// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
|
||||
// (stable ABI) targets. When compiled under the stable ABI target,
|
||||
// TORCH_TARGET_VERSION is defined and torch::Tensor is unavailable, so we
|
||||
// use torch::stable::Tensor instead.
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#endif
|
||||
|
||||
/*
|
||||
This file defines custom epilogues for fusing channel scales, token scales,
|
||||
bias, and activation zero-points onto a GEMM operation using the
|
||||
@@ -15,6 +23,12 @@
|
||||
|
||||
namespace vllm::c3x {
|
||||
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
using TensorType = torch::stable::Tensor;
|
||||
#else
|
||||
using TensorType = torch::Tensor;
|
||||
#endif
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T>
|
||||
@@ -84,7 +98,7 @@ struct ScaledEpilogueBase {
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
static auto args_from_tensor(TensorType const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
@@ -100,7 +114,7 @@ struct ScaledEpilogueBase {
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
||||
static auto args_from_tensor(std::optional<TensorType> const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||
@@ -158,8 +172,8 @@ struct ScaledEpilogue
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
@@ -203,9 +217,9 @@ struct ScaledEpilogueBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& azp_adj,
|
||||
std::optional<TensorType> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& azp_adj,
|
||||
TensorType const& azp,
|
||||
std::optional<TensorType> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
|
||||
|
||||
/*
|
||||
@@ -52,7 +54,7 @@ struct ScaledEpilogueBase {
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
static auto args_from_tensor(torch::stable::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
@@ -68,7 +70,8 @@ struct ScaledEpilogueBase {
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
||||
static auto args_from_tensor(
|
||||
std::optional<torch::stable::Tensor> const& tensor) {
|
||||
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
@@ -117,8 +120,8 @@ struct ScaledEpilogue
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
@@ -160,9 +163,9 @@ struct ScaledEpilogueBias
|
||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||
EVTCompute0, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
static ArgumentType prepare_args(torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -220,10 +223,11 @@ struct ScaledEpilogueBiasAzp
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -298,11 +302,11 @@ struct ScaledEpilogueBiasAzpToken
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -27,4 +27,61 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input,
|
||||
torch::stable::Tensor& output_s,
|
||||
int64_t group_size, double eps, double int8_min,
|
||||
double int8_max);
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_mm(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides, bool per_act_token,
|
||||
bool per_out_ch);
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::stable::Tensor& topk_ids,
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
torch::stable::Tensor& input_permutation,
|
||||
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||
const int64_t n, const int64_t k,
|
||||
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||
const bool is_gated);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::stable::Tensor& expert_first_token_offset,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||
const bool swap_ab);
|
||||
|
||||
void get_cutlass_batched_moe_mm_data(
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
const torch::stable::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k);
|
||||
#endif
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
@@ -25,14 +26,14 @@
|
||||
namespace vllm::c3x {
|
||||
|
||||
// Builds the four-extent GEMM problem shape from the operand sizes:
// {a.size(0), b.size(1), a.size(1), 1}.
//
// The tensor sizes are narrowed to 32-bit because the returned shape stores
// plain ints; the casts make that narrowing explicit.
static inline cute::Shape<int, int, int, int> get_problem_shape(
    torch::stable::Tensor const& a, torch::stable::Tensor const& b) {
  int32_t m = static_cast<int32_t>(a.size(0));
  int32_t n = static_cast<int32_t>(b.size(1));
  int32_t k = static_cast<int32_t>(a.size(1));
  return {m, n, k, 1};
}
|
||||
|
||||
template <typename GemmKernel>
|
||||
void cutlass_gemm_caller(
|
||||
torch::Device device, cute::Shape<int, int, int, int> prob_shape,
|
||||
torch::stable::Device device, cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args,
|
||||
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
|
||||
@@ -50,19 +51,20 @@ void cutlass_gemm_caller(
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(device);
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, device);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
auto stream = get_current_cuda_stream(device.index());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
void cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementC = typename Gemm::ElementC;
|
||||
@@ -4,13 +4,12 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
void cutlass_scaled_mm_azp_sm90_int8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
if (azp) {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<
|
||||
c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj,
|
||||
@@ -0,0 +1,22 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
@@ -130,10 +132,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
@@ -200,11 +202,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
@@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
@@ -196,11 +198,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
int M = a.size(0);
|
||||
if (M <= 256) {
|
||||
using Gemm = typename sm120_blockwise_fp8_config_M64<OutType>::Gemm;
|
||||
@@ -0,0 +1,23 @@
|
||||
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
@@ -101,10 +103,10 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
@@ -120,7 +122,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
|
||||
TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
|
||||
STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
|
||||
|
||||
StrideA a_stride;
|
||||
StrideB b_stride;
|
||||
@@ -161,11 +163,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
// TODO: better heuristics
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, 128, 128, Shape<_128, _128, _128>,
|
||||
@@ -1,52 +1,57 @@
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
|
||||
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias,
|
||||
void dispatch_scaled_mm(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias,
|
||||
Fp8Func fp8_func, Int8Func int8_func,
|
||||
BlockwiseFunc blockwise_func) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
|
||||
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (a.dtype() == torch::kFloat8_e4m3fn) {
|
||||
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) {
|
||||
fp8_func(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
|
||||
int8_func(c, a, b, a_scales, b_scales, bias);
|
||||
} else {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
false, "Int8 not supported on SM", version_num,
|
||||
". Use FP8 quantization instead, or run on older arch (SM < 100).");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
||||
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
||||
STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
||||
STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
||||
int32_t version_num = get_sm_version_num();
|
||||
if (version_num >= 90) {
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
a.size(0) == a_scales.size(0) &&
|
||||
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
|
||||
"a_scale_group_shape must be [1, 128].");
|
||||
TORCH_CHECK(
|
||||
STD_TORCH_CHECK(
|
||||
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
|
||||
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
|
||||
"b_scale_group_shape must be [128, 128].");
|
||||
}
|
||||
|
||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
blockwise_func(c, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_sm100_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales);
|
||||
} // namespace vllm
|
||||
@@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm100_fp8_dispatch.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Non-template entry point for the SM100 FP8 scaled matmul.
//
// Checks that both scale tensors are contiguous, then forwards to
// cutlass_scaled_mm_sm100_fp8_epilogue, choosing the <true> or <false>
// instantiation depending on whether `bias` holds a value.
//
// Parameters:
//   out       - tensor passed by mutable reference and handed on unchanged.
//   a, b      - operand tensors, forwarded as-is (no checks on them here).
//   a_scales,
//   b_scales  - scale tensors; both must be contiguous.
//   bias      - optional; when present its scalar type must equal out's.
//
// Failure: STD_TORCH_CHECK fires when a scale tensor is non-contiguous or the
// bias dtype differs from the output dtype.
void cutlass_scaled_mm_sm100_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
// Both scale tensors are required to be contiguous before dispatching.
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
// Bias path: the bias dtype has to agree with the output dtype; the
// output dtype is appended to the failure message.
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
// The unwrapped bias tensor is passed as the extra trailing argument.
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
|
||||
b_scales, *bias);
|
||||
} else {
|
||||
// No-bias path: same call without the trailing bias argument.
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
@@ -192,8 +194,9 @@ struct sm100_fp8_config_M16_swap_ab {
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
void cutlass_gemm_caller_sm100_fp8(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
@@ -237,15 +240,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
template <typename InType, typename OutType, bool EnableBias,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
EpilogueArgs&&... args) {
|
||||
inline void cutlass_gemm_sm100_fp8_dispatch(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm100_fp8_config_default<InType, OutType,
|
||||
@@ -292,22 +295,24 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
}
|
||||
|
||||
template <bool EnableBias, typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
@@ -0,0 +1,25 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Non-template entry point for the SM120 FP8 scaled matmul.
//
// Checks that both scale tensors are contiguous, then forwards to
// cutlass_scaled_mm_sm120_fp8_epilogue, instantiated with
// c3x::ScaledEpilogueBias when `bias` holds a value and with
// c3x::ScaledEpilogue otherwise.
//
// Parameters:
//   out       - tensor passed by mutable reference and handed on unchanged.
//   a, b      - operand tensors, forwarded as-is (no checks on them here).
//   a_scales,
//   b_scales  - scale tensors; both must be contiguous.
//   bias      - optional; when present its scalar type must equal out's.
//
// Failure: STD_TORCH_CHECK fires when a scale tensor is non-contiguous or the
// bias dtype differs from the output dtype.
void cutlass_scaled_mm_sm120_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
// Both scale tensors are required to be contiguous before dispatching.
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
// Bias path: the bias dtype has to agree with the output dtype; the
// output dtype is appended to the failure message.
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
// Epilogue template that also consumes the unwrapped bias tensor.
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
// No-bias path: scale-only epilogue template, no trailing bias argument.
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
@@ -138,13 +140,15 @@ struct sm120_fp8_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
inline void cutlass_gemm_sm120_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
int M = a.size(0);
|
||||
|
||||
@@ -177,19 +181,21 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
@@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_fp8_dispatch.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Non-template entry point for the SM90 FP8 scaled matmul.
//
// Checks that both scale tensors are contiguous, then forwards to
// cutlass_scaled_mm_sm90_fp8_epilogue, choosing the <true> or <false>
// instantiation depending on whether `bias` holds a value.
//
// Parameters:
//   out       - tensor passed by mutable reference and handed on unchanged.
//   a, b      - operand tensors, forwarded as-is (no checks on them here).
//   a_scales,
//   b_scales  - scale tensors; both must be contiguous.
//   bias      - optional; when present its scalar type must equal out's.
//
// Failure: STD_TORCH_CHECK fires when a scale tensor is non-contiguous or the
// bias dtype differs from the output dtype.
void cutlass_scaled_mm_sm90_fp8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
// Both scale tensors are required to be contiguous before dispatching.
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
// Bias path: the bias dtype has to agree with the output dtype; the
// output dtype is appended to the failure message.
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
// The unwrapped bias tensor is passed as the extra trailing argument.
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
|
||||
b_scales, *bias);
|
||||
} else {
|
||||
// No-bias path: same call without the trailing bias argument.
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
@@ -235,8 +237,9 @@ struct sm90_fp8_config_M16_N8192 {
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
void cutlass_gemm_caller_sm90_fp8(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
@@ -280,15 +283,15 @@ void cutlass_gemm_caller_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
template <typename InType, typename OutType, bool EnableBias,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
EpilogueArgs&&... args) {
|
||||
inline void cutlass_gemm_sm90_fp8_dispatch(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_fp8_config_default<InType, OutType,
|
||||
@@ -347,22 +350,24 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
}
|
||||
|
||||
template <bool EnableBias, typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
void cutlass_scaled_mm_sm90_fp8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, EnableBias>(
|
||||
out, a, b, a_scales, b_scales,
|
||||
@@ -0,0 +1,25 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Non-template entry point for the SM90 int8 scaled matmul.
//
// Checks that both scale tensors are contiguous, then forwards to
// cutlass_scaled_mm_sm90_int8_epilogue, instantiated with
// c3x::ScaledEpilogueBias when `bias` holds a value and with
// c3x::ScaledEpilogue otherwise.
//
// Parameters:
//   out       - tensor passed by mutable reference and handed on unchanged.
//   a, b      - operand tensors, forwarded as-is (no checks on them here).
//   a_scales,
//   b_scales  - scale tensors; both must be contiguous.
//   bias      - optional; when present its scalar type must equal out's.
//
// Failure: STD_TORCH_CHECK fires when a scale tensor is non-contiguous or the
// bias dtype differs from the output dtype.
void cutlass_scaled_mm_sm90_int8(
|
||||
torch::stable::Tensor& out, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
// Both scale tensors are required to be contiguous before dispatching.
STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
// Bias path: the bias dtype has to agree with the output dtype; the
// output dtype is appended to the failure message.
STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
|
||||
"currently bias dtype must match output dtype ",
|
||||
out.scalar_type());
|
||||
// Epilogue template that also consumes the unwrapped bias tensor.
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
// No-bias path: scale-only epilogue template, no trailing bias argument.
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
@@ -87,13 +89,13 @@ struct sm90_int8_config_M32_NSmall {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
inline void cutlass_gemm_sm90_int8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_int8_config_default<InType, OutType,
|
||||
@@ -142,19 +144,19 @@ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_int8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
void cutlass_scaled_mm_sm90_int8_epilogue(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
|
||||
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "core/scalar_type.hpp"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
@@ -31,7 +31,7 @@ __global__ void get_group_gemm_starts(
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
else if (out_tensors.scalar_type() == TENSOR_C_TYPE) { \
|
||||
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int64_t*>(expert_offsets.data_ptr()), \
|
||||
@@ -51,32 +51,39 @@ __global__ void get_group_gemm_starts(
|
||||
namespace {
|
||||
|
||||
void run_get_group_gemm_starts(
|
||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
torch::stable::Tensor const& expert_offsets, torch::stable::Tensor& a_ptrs,
|
||||
torch::stable::Tensor& b_ptrs, torch::stable::Tensor& out_ptrs,
|
||||
torch::stable::Tensor& a_scales_ptrs, torch::stable::Tensor& b_scales_ptrs,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors, torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales) {
|
||||
STD_TORCH_CHECK(a_tensors.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b_tensors.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(a_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
STD_TORCH_CHECK(b_scales.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float);
|
||||
// expect int64_t to avoid overflow during offset calculations
|
||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
|
||||
STD_TORCH_CHECK(expert_offsets.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Long);
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
bool per_act_token = a_scales.numel() != 1;
|
||||
bool per_out_ch = b_scales.numel() != num_experts;
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||
|
||||
if (false) {
|
||||
}
|
||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
|
||||
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::BFloat16,
|
||||
cutlass::bfloat16_t)
|
||||
__CALL_GET_STARTS_KERNEL(torch::headeronly::ScalarType::Half, half)
|
||||
else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
STD_TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "get_group_starts.cuh"
|
||||
@@ -84,13 +85,17 @@ struct cutlass_3x_group_gemm {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_group_gemm_caller(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
void cutlass_group_gemm_caller(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
@@ -98,16 +103,20 @@ void cutlass_group_gemm_caller(
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
auto stream = get_current_cuda_stream(a_tensors.get_device_index());
|
||||
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
|
||||
auto device = a_tensors.device();
|
||||
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::stable::Tensor a_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor b_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor out_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor a_scales_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
torch::stable::Tensor b_scales_ptrs = torch::stable::empty(
|
||||
{num_experts}, torch::headeronly::ScalarType::Long, std::nullopt, device);
|
||||
|
||||
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
|
||||
a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
|
||||
@@ -156,7 +165,7 @@ void cutlass_group_gemm_caller(
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(c_strides.data_ptr())};
|
||||
|
||||
int device_id = a_tensors.device().index();
|
||||
int device_id = a_tensors.get_device_index();
|
||||
static const cutlass::KernelHardwareInfo hw_info{
|
||||
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
device_id)};
|
||||
@@ -170,9 +179,9 @@ void cutlass_group_gemm_caller(
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, device);
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
@@ -1,7 +1,8 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "grouped_mm_c3x.cuh"
|
||||
@@ -62,21 +63,27 @@ struct sm100_fp8_config_N8192 {
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
void run_cutlass_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
void run_cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
STD_TORCH_CHECK(
|
||||
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
STD_TORCH_CHECK(
|
||||
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
|
||||
using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
@@ -107,14 +114,18 @@ void run_cutlass_moe_mm_sm100(
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void dispatch_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
||||
void dispatch_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
@@ -127,13 +138,17 @@ void dispatch_moe_mm_sm100(
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
@@ -1,7 +1,8 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "grouped_mm_c3x.cuh"
|
||||
@@ -103,21 +104,27 @@ struct sm90_fp8_config_N8192 {
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
void run_cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
void run_cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
STD_TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
STD_TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
STD_TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
STD_TORCH_CHECK(
|
||||
a_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
STD_TORCH_CHECK(
|
||||
b_tensors.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
|
||||
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
@@ -163,14 +170,18 @@ void run_cutlass_moe_mm_sm90(
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
||||
void dispatch_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
|
||||
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
@@ -185,13 +196,17 @@ void dispatch_moe_mm_sm90(
|
||||
|
||||
} // namespace
|
||||
|
||||
void cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
@@ -1,9 +1,11 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
#include "libtorch_stable/dispatch_utils.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -110,19 +112,22 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
|
||||
}
|
||||
|
||||
namespace {
|
||||
inline void launch_compute_problem_sizes(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, torch::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n, int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab, const bool is_gated) {
|
||||
inline void launch_compute_problem_sizes(const torch::stable::Tensor& topk_ids,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
torch::stable::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n,
|
||||
int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab,
|
||||
const bool is_gated) {
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
auto const* topk_ptr = topk_ids.data_ptr<int32_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
auto* atomic_ptr = atomic_buffer.data_ptr<int32_t>();
|
||||
auto const* topk_ptr = topk_ids.const_data_ptr<int32_t>();
|
||||
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
|
||||
auto* atomic_ptr = atomic_buffer.mutable_data_ptr<int32_t>();
|
||||
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes<SwapAB><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
@@ -171,46 +176,53 @@ __global__ void compute_problem_sizes_from_expert_offsets(
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
||||
TORCH_CHECK(expert_first_token_offset.is_cuda(),
|
||||
"expert_first_token_offset must be a CUDA tensor");
|
||||
TORCH_CHECK(expert_first_token_offset.dtype() == torch::kInt64,
|
||||
"expert_first_token_offset must be int64");
|
||||
const torch::stable::Tensor& expert_first_token_offset,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||
const bool swap_ab) {
|
||||
STD_TORCH_CHECK(expert_first_token_offset.is_cuda(),
|
||||
"expert_first_token_offset must be a CUDA tensor");
|
||||
STD_TORCH_CHECK(expert_first_token_offset.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Long,
|
||||
"expert_first_token_offset must be int64");
|
||||
|
||||
TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
|
||||
"problem_sizes must be CUDA tensors");
|
||||
TORCH_CHECK(problem_sizes1.dtype() == torch::kInt32 &&
|
||||
problem_sizes2.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
|
||||
"problem_sizes must be contiguous");
|
||||
TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
|
||||
"problem_sizes must be 2D tensors");
|
||||
TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
|
||||
"problem_sizes second dim must be 3");
|
||||
TORCH_CHECK(problem_sizes1.sizes() == problem_sizes2.sizes(),
|
||||
"problem_sizes1 and problem_sizes2 must have same shape");
|
||||
STD_TORCH_CHECK(problem_sizes1.is_cuda() && problem_sizes2.is_cuda(),
|
||||
"problem_sizes must be CUDA tensors");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes1.scalar_type() == torch::headeronly::ScalarType::Int &&
|
||||
problem_sizes2.scalar_type() == torch::headeronly::ScalarType::Int,
|
||||
"problem_sizes must be int32");
|
||||
STD_TORCH_CHECK(
|
||||
problem_sizes1.is_contiguous() && problem_sizes2.is_contiguous(),
|
||||
"problem_sizes must be contiguous");
|
||||
STD_TORCH_CHECK(problem_sizes1.dim() == 2 && problem_sizes2.dim() == 2,
|
||||
"problem_sizes must be 2D tensors");
|
||||
STD_TORCH_CHECK(problem_sizes1.size(1) == 3 && problem_sizes2.size(1) == 3,
|
||||
"problem_sizes second dim must be 3");
|
||||
STD_TORCH_CHECK(problem_sizes1.size(0) == problem_sizes2.size(0) &&
|
||||
problem_sizes1.size(1) == problem_sizes2.size(1),
|
||||
"problem_sizes1 and problem_sizes2 must have same shape");
|
||||
|
||||
int64_t const num_experts64 = problem_sizes1.size(0);
|
||||
TORCH_CHECK(expert_first_token_offset.numel() == num_experts64 + 1,
|
||||
"expert_first_token_offset must have num_experts + 1 elements");
|
||||
TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
|
||||
TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX, "n and k must fit in int32");
|
||||
STD_TORCH_CHECK(
|
||||
expert_first_token_offset.numel() == num_experts64 + 1,
|
||||
"expert_first_token_offset must have num_experts + 1 elements");
|
||||
STD_TORCH_CHECK(num_experts64 <= INT32_MAX, "num_experts must fit in int32");
|
||||
STD_TORCH_CHECK(n <= INT32_MAX && k <= INT32_MAX,
|
||||
"n and k must fit in int32");
|
||||
|
||||
int const num_experts = static_cast<int>(num_experts64);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(
|
||||
expert_first_token_offset.device().index());
|
||||
auto stream =
|
||||
get_current_cuda_stream(expert_first_token_offset.get_device_index());
|
||||
|
||||
int const threads = (num_experts < 256) ? num_experts : 256;
|
||||
int const blocks = (num_experts + threads - 1) / threads;
|
||||
|
||||
auto const* offsets_ptr = expert_first_token_offset.data_ptr<int64_t>();
|
||||
auto* ps1_ptr = problem_sizes1.data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.data_ptr<int32_t>();
|
||||
auto const* offsets_ptr = expert_first_token_offset.const_data_ptr<int64_t>();
|
||||
auto* ps1_ptr = problem_sizes1.mutable_data_ptr<int32_t>();
|
||||
auto* ps2_ptr = problem_sizes2.mutable_data_ptr<int32_t>();
|
||||
|
||||
VLLM_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
VLLM_STABLE_DISPATCH_BOOL(swap_ab, SwapAB, [&] {
|
||||
compute_problem_sizes_from_expert_offsets<SwapAB>
|
||||
<<<blocks, threads, 0, stream>>>(offsets_ptr, ps1_ptr, ps2_ptr,
|
||||
num_experts, static_cast<int>(n),
|
||||
@@ -219,16 +231,19 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const torch::stable::Tensor& topk_ids,
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
torch::stable::Tensor& input_permutation,
|
||||
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||
const int64_t n, const int64_t k,
|
||||
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||
const bool is_gated) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||
auto options_int32 =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||
auto device = topk_ids.device();
|
||||
auto stream = get_current_cuda_stream(device.index());
|
||||
torch::stable::Tensor atomic_buffer = torch::stable::new_zeros(
|
||||
topk_ids, {num_experts}, torch::headeronly::ScalarType::Int);
|
||||
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
@@ -290,11 +305,13 @@ __global__ void compute_batched_moe_data(
|
||||
}
|
||||
|
||||
void get_cutlass_batched_moe_mm_data_caller(
|
||||
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
const torch::stable::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
||||
auto stream = get_current_cuda_stream(expert_offsets.get_device_index());
|
||||
|
||||
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
|
||||
compute_batched_moe_data<false><<<1, num_local_experts, 0, stream>>>(
|
||||
@@ -311,4 +328,4 @@ void get_cutlass_batched_moe_mm_data_caller(
|
||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||
k);
|
||||
}
|
||||
}
|
||||
}
|
||||
220
csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu
Normal file
220
csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu
Normal file
@@ -0,0 +1,220 @@
|
||||
#include <stddef.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "scaled_mm_c2x_sm75_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm80_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
||||
|
||||
#include "libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
|
||||
|
||||
using namespace vllm;
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||
*/
|
||||
|
||||
// Runs the SM75 CUTLASS 2.x GEMM dispatch for int8 inputs with the given
// epilogue, picking the kernel's output element type from the dtype of `out`.
//
//   out            destination tensor; must be bfloat16 or float16
//   a, b           input tensors; both must be int8 (ScalarType::Char)
//   epilogue_args  forwarded unchanged to cutlass_gemm_sm75_dispatch
//
// Any failed dtype requirement is reported through STD_TORCH_CHECK. The
// checked expressions are left spelled out in full so that any message
// derived from the expression text stays the same.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::stable::Tensor& out,
                                     torch::stable::Tensor const& a,
                                     torch::stable::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  using DType = torch::headeronly::ScalarType;

  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);

  // bfloat16 output: handled first, everything else falls through to half.
  if (out.scalar_type() == DType::BFloat16) {
    cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  // The only other supported output type is float16.
  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
  cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
|
||||
|
||||
// SM75 scaled GEMM entry point with an optional bias.
//
// Both scale tensors must be float32. Without a bias the scale-only epilogue
// (c2x::ScaledEpilogue) is used. With a bias, its dtype must equal the dtype
// of `out`, and the bias-fusing epilogue (c2x::ScaledEpilogueBias) is used.
void cutlass_scaled_mm_sm75(torch::stable::Tensor& out,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);

  // No bias supplied: plain scaled epilogue.
  if (!bias.has_value()) {
    cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(out, a, b, a_scales,
                                                         b_scales);
    return;
  }

  STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
                  "currently bias dtype must match output dtype ",
                  out.scalar_type());
  cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
|
||||
|
||||
// SM75 scaled GEMM entry point for the azp variants of the epilogue.
//
// Both scale tensors must be float32. When `azp` is absent the
// c2x::ScaledEpilogueBiasAzp epilogue receives (a_scales, b_scales, azp_adj,
// bias); when it is present the c2x::ScaledEpilogueBiasAzpToken epilogue
// additionally receives the azp tensor. `bias` is passed through as an
// optional in both cases.
void cutlass_scaled_mm_azp_sm75(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);

  // No azp tensor: use the epilogue that only needs the adjustment term.
  if (!azp.has_value()) {
    cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||
|
||||
// Runs the SM80 CUTLASS 2.x GEMM dispatch for int8 inputs with the given
// epilogue, picking the kernel's output element type from the dtype of `out`.
//
//   out            destination tensor; must be bfloat16 or float16
//   a, b           input tensors; both must be int8 (ScalarType::Char)
//   epilogue_args  forwarded unchanged to cutlass_gemm_sm80_dispatch
//
// Any failed dtype requirement is reported through STD_TORCH_CHECK. The
// checked expressions are left spelled out in full so that any message
// derived from the expression text stays the same.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::stable::Tensor& out,
                                     torch::stable::Tensor const& a,
                                     torch::stable::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  using DType = torch::headeronly::ScalarType;

  STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
  STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);

  // bfloat16 output: handled first, everything else falls through to half.
  if (out.scalar_type() == DType::BFloat16) {
    cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  // The only other supported output type is float16.
  STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
  cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
|
||||
|
||||
// SM80 scaled GEMM entry point with an optional bias.
//
// Both scale tensors must be float32. Without a bias the scale-only epilogue
// (c2x::ScaledEpilogue) is used. With a bias, its dtype must equal the dtype
// of `out`, and the bias-fusing epilogue (c2x::ScaledEpilogueBias) is used.
void cutlass_scaled_mm_sm80(torch::stable::Tensor& out,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);

  // No bias supplied: plain scaled epilogue.
  if (!bias.has_value()) {
    cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(out, a, b, a_scales,
                                                         b_scales);
    return;
  }

  STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
                  "currently bias dtype must match output dtype ",
                  out.scalar_type());
  cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
|
||||
|
||||
// SM80 scaled GEMM entry point for the azp variants of the epilogue.
//
// Both scale tensors must be float32. When `azp` is absent the
// c2x::ScaledEpilogueBiasAzp epilogue receives (a_scales, b_scales, azp_adj,
// bias); when it is present the c2x::ScaledEpilogueBiasAzpToken epilogue
// additionally receives the azp tensor. `bias` is passed through as an
// optional in both cases.
void cutlass_scaled_mm_azp_sm80(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);

  // No azp tensor: use the epilogue that only needs the adjustment term.
  if (!azp.has_value()) {
    cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||
|
||||
// Runs the SM89 CUTLASS 2.x GEMM dispatch with the given epilogue, selecting
// the kernel family from the dtype of `a` and the output element type from
// the dtype of `out`.
//
//   out            destination tensor; must be bfloat16 or float16
//   a, b           input tensors; either both int8 (ScalarType::Char) or both
//                  float8_e4m3fn
//   epilogue_args  forwarded unchanged to the selected dispatch routine
//
// Every dtype requirement is enforced with STD_TORCH_CHECK, so an unsupported
// combination is reported as an error in all build configurations.
template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::stable::Tensor& out,
                                     torch::stable::Tensor const& a,
                                     torch::stable::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.scalar_type() == torch::headeronly::ScalarType::Char) {
    // int8 path.
    STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);

    if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
      return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      // Previously a plain assert(), which is compiled out under NDEBUG and
      // would let any other output dtype fall through to the half_t kernel.
      // Use the same runtime check as the fp8 path below.
      STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
      return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    // fp8 (e4m3) path: anything that is not int8 must be float8_e4m3fn.
    STD_TORCH_CHECK(a.scalar_type() ==
                    torch::headeronly::ScalarType::Float8_e4m3fn);
    STD_TORCH_CHECK(b.scalar_type() ==
                    torch::headeronly::ScalarType::Float8_e4m3fn);

    if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) {
      return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half);
      return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}
|
||||
|
||||
// SM89 scaled GEMM entry point with an optional bias.
//
// Both scale tensors must be float32. Without a bias the scale-only epilogue
// (c2x::ScaledEpilogue) is used. With a bias, its dtype must equal the dtype
// of `out`, and the bias-fusing epilogue (c2x::ScaledEpilogueBias) is used.
void cutlass_scaled_mm_sm89(torch::stable::Tensor& out,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);

  // No bias supplied: plain scaled epilogue.
  if (!bias.has_value()) {
    cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(out, a, b, a_scales,
                                                         b_scales);
    return;
  }

  STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(),
                  "currently bias dtype must match output dtype ",
                  out.scalar_type());
  cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
|
||||
|
||||
// SM89 scaled GEMM entry point for the azp variants of the epilogue.
//
// Both scale tensors must be float32. When `azp` is absent the
// c2x::ScaledEpilogueBiasAzp epilogue receives (a_scales, b_scales, azp_adj,
// bias); when it is present the c2x::ScaledEpilogueBiasAzpToken epilogue
// additionally receives the azp tensor. `bias` is passed through as an
// optional in both cases.
void cutlass_scaled_mm_azp_sm89(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias) {
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);

  // No azp tensor: use the epilogue that only needs the adjustment term.
  if (!azp.has_value()) {
    cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||
@@ -1,8 +1,9 @@
|
||||
#pragma once
|
||||
#include <stddef.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
@@ -95,8 +96,9 @@ struct cutlass_2x_gemm {
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
inline void cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
@@ -149,11 +151,12 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
typename Gemm::Op gemm_op;
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
auto device = a.device();
|
||||
auto workspace =
|
||||
torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte,
|
||||
std::nullopt, device);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
auto stream = get_current_cuda_stream(device.index());
|
||||
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
|
||||
@@ -161,9 +164,9 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
}
|
||||
|
||||
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
|
||||
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
inline void fallback_cutlass_gemm_caller(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
// In some cases, the GPU isn't able to accommodate the
|
||||
// shared memory requirements of the Gemm. In such cases, use
|
||||
@@ -180,8 +183,8 @@ inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
|
||||
return cutlass_gemm_caller<Gemm>(out, a, b,
|
||||
std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
TORCH_CHECK(fallback_gemm_shared_mem_size <=
|
||||
max_shared_mem_per_block_opt_in);
|
||||
STD_TORCH_CHECK(fallback_gemm_shared_mem_size <=
|
||||
max_shared_mem_per_block_opt_in);
|
||||
return cutlass_gemm_caller<FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
@@ -70,13 +72,13 @@ struct sm75_config_M32 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
inline void cutlass_gemm_sm75_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
using Cutlass2xGemmDefault =
|
||||
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
@@ -72,13 +74,13 @@ struct sm80_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
inline void cutlass_gemm_sm80_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
using Cutlass2xGemmDefault =
|
||||
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
@@ -34,10 +36,12 @@ struct sm89_fp8_config_default {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -84,10 +88,12 @@ struct sm89_fp8_config_M256 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -125,10 +131,12 @@ struct sm89_fp8_config_M128 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -173,10 +181,12 @@ struct sm89_fp8_config_M64 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -227,10 +237,12 @@ struct sm89_fp8_config_M32 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -280,10 +292,12 @@ struct sm89_fp8_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
@@ -326,13 +340,15 @@ struct sm89_fp8_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
inline void cutlass_gemm_sm89_fp8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
STD_TORCH_CHECK(a.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
STD_TORCH_CHECK(b.scalar_type() ==
|
||||
torch::headeronly::ScalarType::Float8_e4m3fn);
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
@@ -32,10 +34,11 @@ struct sm89_int8_config_default {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -88,10 +91,11 @@ struct sm89_int8_config_M256 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -143,10 +147,11 @@ struct sm89_int8_config_M128 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -193,10 +198,11 @@ struct sm89_int8_config_M64 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -234,10 +240,11 @@ struct sm89_int8_config_M32 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -276,10 +283,11 @@ struct sm89_int8_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static void dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
@@ -311,13 +319,13 @@ struct sm89_int8_config_M16 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
inline void cutlass_gemm_sm89_int8_dispatch(torch::stable::Tensor& out,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
STD_TORCH_CHECK(b.scalar_type() == torch::headeronly::ScalarType::Char);
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
@@ -8,11 +8,12 @@
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||
vllm::cutlass_scaled_mm_sm100_fp8,
|
||||
nullptr, // int8 not supported on SM100
|
||||
@@ -8,11 +8,12 @@
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
|
||||
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias) {
|
||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||
vllm::cutlass_scaled_mm_sm120_fp8,
|
||||
nullptr, // int8 not supported on SM120
|
||||
@@ -0,0 +1,38 @@
|
||||
#include "c3x/scaled_mm_helper.hpp"
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
|
||||
// Hopper (sm90a) entry point for the scaled GEMM.
//
// Performs no validation of its own: every argument is forwarded unchanged to
// dispatch_scaled_mm, together with the three SM90 kernel launchers it is
// allowed to choose from.
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
                            torch::stable::Tensor const& a,
                            torch::stable::Tensor const& b,
                            torch::stable::Tensor const& a_scales,
                            torch::stable::Tensor const& b_scales,
                            std::optional<torch::stable::Tensor> const& bias) {
  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
                     /* fp8 launcher */
                     vllm::cutlass_scaled_mm_sm90_fp8,
                     /* int8 launcher */
                     vllm::cutlass_scaled_mm_sm90_int8,
                     /* blockwise fp8 launcher */
                     vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
}
|
||||
|
||||
// Hopper (sm90a) entry point for the scaled GEMM with activation zero-points.
//
// Verifies that both scale tensors have scalar type Float, then forwards all
// arguments unchanged to the SM90 int8 AZP launcher.
void cutlass_scaled_mm_azp_sm90(
    torch::stable::Tensor& out, torch::stable::Tensor const& a,
    torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
    torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
    std::optional<torch::stable::Tensor> const& azp,
    std::optional<torch::stable::Tensor> const& bias) {
  // Scale tensors: Float only.
  STD_TORCH_CHECK(a_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);
  STD_TORCH_CHECK(b_scales.scalar_type() ==
                  torch::headeronly::ScalarType::Float);

  vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales,
                                        azp_adj, azp, bias);
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,451 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
|
||||
#include "libtorch_stable/torch_utils.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
void cutlass_scaled_mm_sm75(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm80(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm89(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_sm90(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
#endif
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
void cutlass_moe_mm_sm90(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
|
||||
void cutlass_moe_mm_sm100(torch::stable::Tensor& out_tensors,
|
||||
torch::stable::Tensor const& a_tensors,
|
||||
torch::stable::Tensor const& b_tensors,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
torch::stable::Tensor const& expert_offsets,
|
||||
torch::stable::Tensor const& problem_sizes,
|
||||
torch::stable::Tensor const& a_strides,
|
||||
torch::stable::Tensor const& b_strides,
|
||||
torch::stable::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
void cutlass_scaled_mm_sm120(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
void cutlass_scaled_mm_sm100(torch::stable::Tensor& c,
|
||||
torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b,
|
||||
torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::stable::Tensor& topk_ids,
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
torch::stable::Tensor& input_permutation,
|
||||
torch::stable::Tensor& output_permutation, const int64_t num_experts,
|
||||
const int64_t n, const int64_t k,
|
||||
const std::optional<torch::stable::Tensor>& blockscale_offsets,
|
||||
const bool is_gated);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::stable::Tensor& expert_first_token_offset,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
|
||||
const bool swap_ab);
|
||||
|
||||
void get_cutlass_batched_moe_mm_data_caller(
|
||||
torch::stable::Tensor& expert_offsets,
|
||||
torch::stable::Tensor& problem_sizes1,
|
||||
torch::stable::Tensor& problem_sizes2,
|
||||
const torch::stable::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(
|
||||
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(
|
||||
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(
|
||||
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_azp_sm90(
|
||||
torch::stable::Tensor& c, torch::stable::Tensor const& a,
|
||||
torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales,
|
||||
torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj,
|
||||
std::optional<torch::stable::Tensor> const& azp,
|
||||
std::optional<torch::stable::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
// Tells whether the CUTLASS FP8 scaled_mm kernels are usable for a device of
// the given capability with the CUDA toolkit this file was built against.
//
// Minimum toolkit per capability:
//   capability >= 90 (SM90, Hopper)   -> CUDA 12.0
//   capability == 89 (SM89, Lovelace) -> CUDA 12.4
// Any lower capability, or a build where CUDA_VERSION is not defined, yields
// false.
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  // Guard clauses, highest capability tier first.
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  }
  if (cuda_device_capability >= 89) {
    return CUDA_VERSION >= 12040;
  }
#endif

  return false;
}
|
||||
|
||||
// Tells whether the CUTLASS block-quantized FP8 kernels are usable for a
// device of the given capability with the CUDA toolkit this file was built
// against.
//
// Minimum toolkit per capability:
//   capability >= 100        -> CUDA 12.8
//   90 <= capability < 100   -> CUDA 12.0 (SM90, Hopper)
// Any lower capability, or a build where CUDA_VERSION is not defined, yields
// false.
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) {
    // Pick the toolkit threshold that belongs to this capability tier.
    int const required_cuda_version =
        cuda_device_capability >= 100 ? 12080 : 12000;
    return CUDA_VERSION >= required_cuda_version;
  }
#endif

  return false;
}
|
||||
|
||||
// Tells whether the CUTLASS grouped FP8 GEMM kernels are usable for a device
// of the given capability with the CUDA toolkit this file was built against.
//
// Minimum toolkit per capability:
//   capability >= 100 (SM100, Blackwell) -> CUDA 12.8
//   90 <= capability < 100 (SM90, Hopper) -> CUDA 12.3
// Any lower capability, or a build where CUDA_VERSION is not defined, yields
// false.
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) {
    // Pick the toolkit threshold that belongs to this capability tier.
    int const required_cuda_version =
        cuda_device_capability >= 100 ? 12080 : 12030;
    return CUDA_VERSION >= required_cuda_version;
  }
#endif

  return false;
}
|
||||
|
||||
// Top-level scaled GEMM entry point.
//
// Validates operand shapes and strides, installs a device guard for `a`'s
// device, reads the SM version and forwards to the first kernel family that
// is both compiled in and matches that version.
//
//   c                  - output; 2-D, stride(1) == 1, stride(0) % 16 == 0
//   a                  - left operand; 2-D, stride(1) == 1
//   b                  - right operand; 2-D, stride(0) == 1,
//                        stride(1) % 16 == 0
//   a_scales, b_scales - forwarded untouched to the selected kernel
//   bias               - optional; when present must be 1-D, contiguous and
//                        hold b.size(1) elements
//
// A violated precondition trips STD_TORCH_CHECK; a device for which no kernel
// was compiled trips STD_TORCH_CHECK_NOT_IMPLEMENTED.
void cutlass_scaled_mm(torch::stable::Tensor& c, torch::stable::Tensor const& a,
                       torch::stable::Tensor const& b,
                       torch::stable::Tensor const& a_scales,
                       torch::stable::Tensor const& b_scales,
                       std::optional<torch::stable::Tensor> const& bias) {
  // Shapes must conform: c[m, n] = a[m, k] * b[k, n].
  STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
                  b.size(1) == c.size(1));

  // Layout requirements.
  STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  STD_TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
                  b.stride(1) % 16 == 0);  // 16 Byte Alignment

  if (bias.has_value()) {
    STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                    bias->dim() == 1);
  }

  // The guard must be in place before the SM version is queried.
  const torch::stable::accelerator::DeviceGuard device_guard(
      a.get_device_index());
  int32_t sm_version = get_sm_version_num();

  // Kernel families are tried from newest to oldest; each branch returns as
  // soon as it has launched.

#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
  if (sm_version >= 120) {
    cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
  if (100 <= sm_version && sm_version < 120) {
    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  if (90 <= sm_version && sm_version < 100) {
    // Hopper
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (sm_version == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
    return;
  }

  if (sm_version >= 80) {
    // Ampere
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
    return;
  }

  if (sm_version >= 75) {
    // Turing
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

  // Nothing compiled in matches this device.
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm for a compute capability less than "
      "CUDA device capability: ",
      sm_version);
}
|
||||
|
||||
// Grouped (MoE) GEMM entry point.
//
// Reads the SM version and forwards every argument unchanged to the matching
// compiled kernel family:
//   100 <= version < 110 -> cutlass_moe_mm_sm100 (if ENABLE_CUTLASS_MOE_SM100)
//    90 <= version < 100 -> cutlass_moe_mm_sm90  (if ENABLE_CUTLASS_MOE_SM90)
// Any other version, or a build without the matching family, trips
// STD_TORCH_CHECK_NOT_IMPLEMENTED.
//
// No argument validation happens here; the tensors are passed straight to
// the selected kernel.
//
// NOTE(review): unlike cutlass_scaled_mm, no DeviceGuard is installed before
// get_sm_version_num() is called - confirm callers always run with the
// tensors' device already current.
void cutlass_moe_mm(torch::stable::Tensor& out_tensors,
                    torch::stable::Tensor const& a_tensors,
                    torch::stable::Tensor const& b_tensors,
                    torch::stable::Tensor const& a_scales,
                    torch::stable::Tensor const& b_scales,
                    torch::stable::Tensor const& expert_offsets,
                    torch::stable::Tensor const& problem_sizes,
                    torch::stable::Tensor const& a_strides,
                    torch::stable::Tensor const& b_strides,
                    torch::stable::Tensor const& c_strides, bool per_act_token,
                    bool per_out_ch) {
  int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
  if (version_num >= 100 && version_num < 110) {
    cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                         expert_offsets, problem_sizes, a_strides, b_strides,
                         c_strides, per_act_token, per_out_ch);
    return;
  }
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
  if (version_num >= 90 && version_num < 100) {
    cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                        expert_offsets, problem_sizes, a_strides, b_strides,
                        c_strides, per_act_token, per_out_ch);
    return;
  }
#endif
  // The message names this operator (it previously said cutlass_scaled_mm,
  // which pointed users at the wrong op when a MoE kernel was missing).
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_moe_mm for CUDA device capability: ", version_num,
      ". Required capability: 90 or 100");
}
|
||||
|
||||
// Entry point that prepares the auxiliary data for the CUTLASS MoE GEMM.
//
// When at least one of the SM90 / SM100 / SM120 MoE kernel families is
// compiled in, all arguments are forwarded unchanged to
// get_cutlass_moe_mm_data_caller and the function returns. In a build with
// none of them, the call fails with STD_TORCH_CHECK_NOT_IMPLEMENTED, reporting
// the SM version that was queried.
void get_cutlass_moe_mm_data(
    const torch::stable::Tensor& topk_ids,
    torch::stable::Tensor& expert_offsets,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2,
    torch::stable::Tensor& input_permutation,
    torch::stable::Tensor& output_permutation, const int64_t num_experts,
    const int64_t n, const int64_t k,
    const std::optional<torch::stable::Tensor>& blockscale_offsets,
    const bool is_gated) {
  // Queried up front; only consumed by the failure message below.
  int32_t sm_version = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
  // The caller is only declared when a CUTLASS MoE kernel family exists.
  get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
                                 problem_sizes2, input_permutation,
                                 output_permutation, num_experts, n, k,
                                 blockscale_offsets, is_gated);
  return;
#endif
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
      "CUDA device capability: ",
      sm_version, ". Required capability: 90, 100, or 120");
}
|
||||
|
||||
// Entry point that fills problem_sizes1 / problem_sizes2 from
// expert_first_token_offset.
//
// When at least one of the SM90 / SM100 / SM120 MoE kernel families is
// compiled in, all arguments are forwarded unchanged to the *_caller
// implementation and the function returns. In a build with none of them, the
// call fails with STD_TORCH_CHECK_NOT_IMPLEMENTED, reporting the SM version
// that was queried.
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
    const torch::stable::Tensor& expert_first_token_offset,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k,
    const bool swap_ab) {
  // Queried up front; only consumed by the failure message below.
  int32_t sm_version = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
  get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
      expert_first_token_offset, problem_sizes1, problem_sizes2, n, k,
      swap_ab);
  return;
#endif
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
      "no cutlass_scaled_mm kernel for CUDA device capability: ",
      sm_version, ". Required capability: 90, 100, or 120");
}
|
||||
|
||||
// Entry point that prepares the auxiliary data for the CUTLASS MoE GEMM in
// batched-expert form.
//
// When at least one of the SM90 / SM100 / SM120 MoE kernel families is
// compiled in, all arguments are forwarded unchanged to
// get_cutlass_batched_moe_mm_data_caller and the function returns. In a build
// with none of them, the call fails with STD_TORCH_CHECK_NOT_IMPLEMENTED,
// reporting the SM version that was queried.
void get_cutlass_batched_moe_mm_data(
    torch::stable::Tensor& expert_offsets,
    torch::stable::Tensor& problem_sizes1,
    torch::stable::Tensor& problem_sizes2,
    const torch::stable::Tensor& expert_num_tokens,
    const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
    const int64_t k) {
  // Queried up front; only consumed by the failure message below.
  int32_t sm_version = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
  // The caller is only declared when a CUTLASS MoE kernel family exists.
  get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1,
                                         problem_sizes2, expert_num_tokens,
                                         num_local_experts, padded_m, n, k);
  return;
#endif
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_batched_moe_mm_data: no "
      "cutlass_scaled_mm kernel "
      "for CUDA device capability: ",
      sm_version, ". Required capability: 90, 100, or 120");
}
|
||||
|
||||
// Top-level scaled GEMM entry point for asymmetric quantization (activation
// zero-points).
//
// Validates shapes, strides and dtypes, installs a device guard for `a`'s
// device, reads the SM version and forwards to the first kernel family that
// is both compiled in and matches that version.
//
//   c        - output; 2-D, stride(1) == 1, stride(0) % 16 == 0
//   a        - left operand; 2-D, stride(1) == 1
//   b        - right operand; 2-D, stride(0) == 1, stride(1) % 16 == 0
//   a_scales - contiguous; 1 element or a.size(0) elements
//   b_scales - contiguous; 1 element or b.size(1) elements
//   azp_adj  - contiguous Int tensor with b.size(1) elements
//   azp      - optional; contiguous Int tensor with a.size(0) elements
//   bias     - optional; contiguous, b.size(1) elements, same dtype as c
//
// A violated precondition trips STD_TORCH_CHECK; a build with no matching
// kernel trips STD_TORCH_CHECK_NOT_IMPLEMENTED.
void cutlass_scaled_mm_azp(torch::stable::Tensor& c,
                           torch::stable::Tensor const& a,
                           torch::stable::Tensor const& b,
                           torch::stable::Tensor const& a_scales,
                           torch::stable::Tensor const& b_scales,
                           torch::stable::Tensor const& azp_adj,
                           std::optional<torch::stable::Tensor> const& azp,
                           std::optional<torch::stable::Tensor> const& bias) {
  // Shapes must conform: c[m, n] = a[m, k] * b[k, n]; each scale tensor is
  // either a single value or one value per row of a / column of b.
  STD_TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  STD_TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
                  b.size(1) == c.size(1));
  STD_TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  STD_TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Layout requirements.
  STD_TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  STD_TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  STD_TORCH_CHECK(c.stride(0) % 16 == 0 &&
                  b.stride(1) % 16 == 0);  // 16 Byte Alignment
  STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // Element counts of the 1-D side inputs: bias and azp_adj carry n values,
  // azp carries m values.
  if (bias.has_value()) {
    STD_TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp.has_value()) {
    STD_TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  STD_TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // Dtypes of the zero-point tensors and of the bias.
  STD_TORCH_CHECK(azp_adj.scalar_type() == torch::headeronly::ScalarType::Int);
  STD_TORCH_CHECK(!azp ||
                  azp->scalar_type() == torch::headeronly::ScalarType::Int);
  STD_TORCH_CHECK(!bias || bias->scalar_type() == c.scalar_type(),
                  "currently bias dtype must match output dtype ",
                  c.scalar_type());

  // The guard must be in place before the SM version is queried.
  const torch::stable::accelerator::DeviceGuard device_guard(
      a.get_device_index());

  int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  // Every version from 90 upwards is routed to the SM90 kernel.
  if (version_num >= 90) {
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  // Turing is the last fallback; anything older fails the check below.
  STD_TORCH_CHECK(version_num >= 75);
  cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
  return;
#endif

  // Nothing compiled in matches this device.
  STD_TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm_azp for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
|
||||
@@ -31,6 +31,78 @@ STABLE_TORCH_LIBRARY_FRAGMENT(_C, ops) {
|
||||
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
|
||||
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
|
||||
"()");
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization, as well as bias
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor azp_adj,"
|
||||
" Tensor? azp, Tensor? bias) -> ()");
|
||||
|
||||
// Check if cutlass scaled_mm is supported for CUDA devices of the given
|
||||
// capability
|
||||
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
||||
|
||||
// Check if cutlass grouped gemm is supported for CUDA devices of the given
|
||||
// capability
|
||||
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
|
||||
|
||||
// CUTLASS w8a8 grouped GEMM
|
||||
ops.def(
|
||||
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
|
||||
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
||||
" Tensor problem_sizes, Tensor a_strides, "
|
||||
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
|
||||
" bool per_out_ch) -> ()");
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM. It takes topk_ids as an input, and computes expert_offsets
|
||||
// (token start indices of each expert). In addition to this, it computes
|
||||
// problem sizes for each expert's multiplication used by the two mms called
|
||||
// from fused MoE operation, and arrays with permutations required to shuffle
|
||||
// and de-shuffle the input/output of the fused operation.
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||
" Tensor! input_permutation, "
|
||||
" Tensor! output_permutation, int num_experts, "
|
||||
" int n, int k, Tensor? blockscale_offsets, "
|
||||
" bool is_gated) -> ()");
|
||||
|
||||
// compute per-expert problem sizes from expert_first_token_offset
|
||||
// produced by vLLM's moe_permute kernel
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
|
||||
" Tensor expert_first_token_offset, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" int n, int k, bool swap_ab) -> ()");
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM in batched expert format. It takes expert_num_tokens
|
||||
// as an input, and computes expert_offsets (token start indices of each
|
||||
// expert). In addition to this, it computes problem sizes for each expert's
|
||||
// multiplication used by the two mms called from fused MoE operation.
|
||||
ops.def(
|
||||
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" Tensor expert_num_tokens, "
|
||||
" int num_local_experts, int padded_m, "
|
||||
" int n, int k) -> ()");
|
||||
|
||||
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||
"bool");
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -46,6 +118,31 @@ STABLE_TORCH_LIBRARY_IMPL(_C, CUDA, ops) {
|
||||
TORCH_BOX(&per_token_group_quant_8bit_packed));
|
||||
ops.impl("per_token_group_quant_int8",
|
||||
TORCH_BOX(&per_token_group_quant_int8));
|
||||
|
||||
// CUTLASS scaled_mm ops
|
||||
ops.impl("cutlass_scaled_mm", TORCH_BOX(&cutlass_scaled_mm));
|
||||
ops.impl("cutlass_scaled_mm_azp", TORCH_BOX(&cutlass_scaled_mm_azp));
|
||||
ops.impl("cutlass_moe_mm", TORCH_BOX(&cutlass_moe_mm));
|
||||
ops.impl("get_cutlass_moe_mm_data", TORCH_BOX(&get_cutlass_moe_mm_data));
|
||||
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets",
|
||||
TORCH_BOX(&get_cutlass_moe_mm_problem_sizes_from_expert_offsets));
|
||||
ops.impl("get_cutlass_batched_moe_mm_data",
|
||||
TORCH_BOX(&get_cutlass_batched_moe_mm_data));
|
||||
#endif
|
||||
}
|
||||
|
||||
// These capability-check functions take only primitive args (no tensors), so
|
||||
// there is no device to dispatch on. CompositeExplicitAutograd makes them
|
||||
// available for all backends. This is the stable ABI equivalent of calling
|
||||
// ops.impl("op_name", &func) without a dispatch key in the non-stable API.
|
||||
STABLE_TORCH_LIBRARY_IMPL(_C, CompositeExplicitAutograd, ops) {
|
||||
#ifndef USE_ROCM
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8",
|
||||
TORCH_BOX(&cutlass_scaled_mm_supports_fp8));
|
||||
ops.impl("cutlass_group_gemm_supported",
|
||||
TORCH_BOX(&cutlass_group_gemm_supported));
|
||||
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||
TORCH_BOX(&cutlass_scaled_mm_supports_block_fp8));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/stable/accelerator.h>
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#include <torch/headeronly/util/shim_utils.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
// Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED.
// Forwards `cond` and the caller's message arguments to STD_TORCH_CHECK,
// prefixing the message with the literal text "NotImplementedError: ".
|
||||
#define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
|
||||
STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__)
|
||||
|
||||
// Utility to get the current CUDA stream for a given device using stable APIs.
|
||||
// Returns a cudaStream_t for use in kernel launches.
|
||||
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {
|
||||
|
||||
45
csrc/ops.h
45
csrc/ops.h
@@ -228,63 +228,18 @@ int64_t ggml_moe_get_block_size(int64_t type);
|
||||
#ifndef USE_ROCM
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B, torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_moe_mm(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
|
||||
void cutlass_fp4_group_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,
|
||||
const torch::Tensor& alphas, const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets);
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab);
|
||||
|
||||
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
const torch::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts,
|
||||
const int64_t padded_m, const int64_t n,
|
||||
const int64_t k);
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
|
||||
torch::Tensor const& input, torch::Tensor const& input_scale,
|
||||
bool is_sf_swizzled_layout);
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,23 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,24 +0,0 @@
|
||||
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
cutlass_gemm_blockwise_sm90_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,56 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
} // namespace vllm
|
||||
@@ -1,23 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm100_fp8_dispatch.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales,
|
||||
b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,24 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,23 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_fp8_dispatch.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<true>(out, a, b, a_scales,
|
||||
b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_fp8_epilogue<false>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,24 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm90_int8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm90_int8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@@ -1,199 +0,0 @@
|
||||
#include <stddef.h>
|
||||
#include <torch/all.h>
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "scaled_mm_c2x_sm75_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm80_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
||||
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
|
||||
|
||||
using namespace vllm;
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||
*/
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
// SM75 scaled-mm entry point.
// Both scale tensors must be float32. When `bias` is present its dtype must
// equal `out.dtype()` and c2x::ScaledEpilogueBias is used with the bias
// tensor; otherwise c2x::ScaledEpilogue is used.
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!bias) {
    // No bias supplied: plain scaled epilogue.
    cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(out, a, b, a_scales,
                                                         b_scales);
    return;
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
|
||||
|
||||
// SM75 scaled-mm entry point for the `azp` variant.
// Both scale tensors must be float32. When `azp` is present,
// c2x::ScaledEpilogueBiasAzpToken is used and receives `azp_adj`, the azp
// tensor and the optional `bias`; otherwise c2x::ScaledEpilogueBiasAzp is
// used with `azp_adj` and the optional `bias`.
void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp) {
    // No azp tensor supplied: only the precomputed adjustment is applied.
    cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
// SM80 scaled-mm entry point.
// Both scale tensors must be float32. When `bias` is present its dtype must
// equal `out.dtype()` and c2x::ScaledEpilogueBias is used with the bias
// tensor; otherwise c2x::ScaledEpilogue is used.
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!bias) {
    // No bias supplied: plain scaled epilogue.
    cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(out, a, b, a_scales,
                                                         b_scales);
    return;
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
|
||||
|
||||
// SM80 scaled-mm entry point for the `azp` variant.
// Both scale tensors must be float32. When `azp` is present,
// c2x::ScaledEpilogueBiasAzpToken is used and receives `azp_adj`, the azp
// tensor and the optional `bias`; otherwise c2x::ScaledEpilogueBiasAzp is
// used with `azp_adj` and the optional `bias`.
void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp) {
    // No azp tensor supplied: only the precomputed adjustment is applied.
    cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
assert(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SM89 scaled-mm entry point.
// Both scale tensors must be float32. When `bias` is present its dtype must
// equal `out.dtype()` and c2x::ScaledEpilogueBias is used with the bias
// tensor; otherwise c2x::ScaledEpilogue is used.
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!bias) {
    // No bias supplied: plain scaled epilogue.
    cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(out, a, b, a_scales,
                                                         b_scales);
    return;
  }

  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
|
||||
|
||||
// SM89 scaled-mm entry point for the `azp` variant.
// Both scale tensors must be float32. When `azp` is present,
// c2x::ScaledEpilogueBiasAzpToken is used and receives `azp_adj`, the azp
// tensor and the optional `bias`; otherwise c2x::ScaledEpilogueBiasAzp is
// used with `azp_adj` and the optional `bias`.
void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (!azp) {
    // No azp tensor supplied: only the precomputed adjustment is applied.
    cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
    return;
  }

  cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
      out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
}
|
||||
@@ -1,36 +0,0 @@
|
||||
#include "c3x/scaled_mm_helper.hpp"
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
|
||||
// SM90 scaled-mm entry point.
// Hands all arguments to dispatch_scaled_mm together with the three SM90
// kernel entry points (fp8, int8 and blockwise fp8) it may select from.
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            std::optional<torch::Tensor> const& bias) {
  namespace kernels = vllm;

  dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
                     kernels::cutlass_scaled_mm_sm90_fp8,
                     kernels::cutlass_scaled_mm_sm90_int8,
                     kernels::cutlass_scaled_mm_blockwise_sm90_fp8);
}
|
||||
|
||||
// SM90 scaled-mm entry point for the `azp` variant.
// Both scale tensors must be float32; every argument is then passed through
// unchanged to vllm::cutlass_scaled_mm_azp_sm90_int8.
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                std::optional<torch::Tensor> const& azp,
                                std::optional<torch::Tensor> const& bias) {
  namespace kernels = vllm;

  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  kernels::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales,
                                           azp_adj, azp, bias);
}
|
||||
|
||||
#endif
|
||||
@@ -1,420 +0,0 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
void cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
|
||||
void cutlass_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if (defined(ENABLE_CUTLASS_MOE_SM90) && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined(ENABLE_CUTLASS_MOE_SM120) && ENABLE_CUTLASS_MOE_SM120)
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab);
|
||||
|
||||
void get_cutlass_batched_moe_mm_data_caller(
|
||||
torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens,
|
||||
const int64_t num_local_experts, const int64_t padded_m, const int64_t n,
|
||||
const int64_t k);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
// Reports whether the CUTLASS FP8 scaled-mm kernels can be used on a device
// with the given compute capability (e.g. 89, 90), based on the CUDA version
// this file was compiled with:
//   * capability >= 90 (Hopper and newer): CUDA 12.0 or later
//   * capability == 89 (Lovelace):         CUDA 12.4 or later
//   * anything older, or CUDA_VERSION not defined at compile time: false
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability < 89) {
    return false;
  }
  // Minimum toolkit version depends on which side of SM90 the device is on.
  int const min_cuda_version = cuda_device_capability >= 90 ? 12000 : 12040;
  return CUDA_VERSION >= min_cuda_version;
#else
  return false;
#endif
}
|
||||
|
||||
// Reports whether the CUTLASS block-quantized FP8 scaled-mm kernels can be
// used on a device with the given compute capability, based on the CUDA
// version this file was compiled with:
//   * capability >= 100:      CUDA 12.8 or later
//   * 90 <= capability < 100: CUDA 12.0 or later
//   * anything older, or CUDA_VERSION not defined at compile time: false
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability < 90) {
    return false;
  }
  // Minimum toolkit version depends on which side of SM100 the device is on.
  int const min_cuda_version = cuda_device_capability >= 100 ? 12080 : 12000;
  return CUDA_VERSION >= min_cuda_version;
#else
  return false;
#endif
}
|
||||
|
||||
// Reports whether the CUTLASS grouped FP8 GEMM kernels can be used on a
// device with the given compute capability, based on the CUDA version this
// file was compiled with:
//   * capability >= 100 (Blackwell):       CUDA 12.8 or later
//   * 90 <= capability < 100 (Hopper):     CUDA 12.3 or later
//   * anything older, or CUDA_VERSION not defined at compile time: false
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability < 90) {
    return false;
  }
  // Minimum toolkit version depends on which side of SM100 the device is on.
  int const min_cuda_version = cuda_device_capability >= 100 ? 12080 : 12030;
  return CUDA_VERSION >= min_cuda_version;
#else
  return false;
#endif
}
|
||||
|
||||
// Top-level scaled-mm entry point: validates the operands, then routes the
// call to the kernel family compiled for the current device's SM version.
//
// Preconditions enforced with TORCH_CHECK:
//   * a, b, c are 2-D with c = [a.size(0), b.size(1)] and
//     a.size(1) == b.size(0)
//   * a and c have unit stride in dim 1, b has unit stride in dim 0
//   * c.stride(0) and b.stride(1) are multiples of 16
//   * bias, when given, is a contiguous 1-D tensor with b.size(1) elements
//
// Dispatch order (each tier only exists if its ENABLE_* macro is set):
// SM120+, then SM100..119, then SM90..99, then the C2X kernels (SM89, SM80+,
// SM75+). If no compiled tier matches, a not-implemented error is raised.
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       std::optional<torch::Tensor> const& bias) {
  // Shape conformality.
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));

  // Layout: a and c row-major, b column-major, leading strides aligned.
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);
  TORCH_CHECK(b.stride(0) == 1);
  TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0);

  if (bias.has_value()) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  // Make a's device current for the rest of this call before querying it.
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t const sm_version = get_sm_version_num();

#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
  if (sm_version >= 120) {
    cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
  if (sm_version >= 100 && sm_version < 120) {
    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  if (sm_version >= 90 && sm_version < 100) {
    // Hopper
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (sm_version >= 75) {
    if (sm_version == 89) {
      // Ada Lovelace
      cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
    } else if (sm_version >= 80) {
      // Ampere
      cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
    } else {
      // Turing
      cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
    }
    return;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm for a compute capability less than "
      "CUDA device capability: ",
      sm_version);
}
|
||||
|
||||
// Entry point for the CUTLASS grouped (MoE) matrix multiply. Selects the
// implementation compiled for the current device and forwards every argument
// to it unchanged.
//
// Dispatch is based on get_sm_version_num():
//   - capability in [100, 110): cutlass_moe_mm_sm100, if
//     ENABLE_CUTLASS_MOE_SM100 is set to a non-zero value
//   - capability in [90, 100):  cutlass_moe_mm_sm90, if
//     ENABLE_CUTLASS_MOE_SM90 is set to a non-zero value
// If no compiled branch matches, a not-implemented error naming this op and
// the detected capability is raised.
void cutlass_moe_mm(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
    bool per_act_token, bool per_out_ch) {
  int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
  if (version_num >= 100 && version_num < 110) {
    cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                         expert_offsets, problem_sizes, a_strides, b_strides,
                         c_strides, per_act_token, per_out_ch);
    return;
  }
#endif
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
  if (version_num >= 90 && version_num < 100) {
    cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                        expert_offsets, problem_sizes, a_strides, b_strides,
                        c_strides, per_act_token, per_out_ch);
    return;
  }
#endif
  // The message names cutlass_moe_mm (not cutlass_scaled_mm) so the failure
  // points at the op that was actually called.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_moe_mm for CUDA device capability: ", version_num,
      ". Required capability: 90 or 100");
}
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets,
|
||||
const bool is_gated) {
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, input_permutation,
|
||||
output_permutation, num_experts, n, k,
|
||||
blockscale_offsets, is_gated);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
||||
"CUDA device capability: ",
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
const torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
const int64_t n, const int64_t k, const bool swap_ab) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
|
||||
get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller(
|
||||
expert_first_token_offset, problem_sizes1, problem_sizes2, n, k, swap_ab);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_problem_sizes_from_expert_offsets: "
|
||||
"no cutlass_scaled_mm kernel for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90, 100, or 120");
|
||||
}
|
||||
|
||||
// Thin dispatcher around get_cutlass_batched_moe_mm_data_caller.
//
// When at least one of ENABLE_CUTLASS_MOE_SM90 / _SM100 / _SM120 is set to a
// non-zero value at build time, every argument is forwarded unchanged to the
// caller helper and the function returns. Otherwise the helper call is not
// compiled in and a not-implemented error reporting the detected device
// capability is raised.
void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets,
                                     torch::Tensor& problem_sizes1,
                                     torch::Tensor& problem_sizes2,
                                     const torch::Tensor& expert_num_tokens,
                                     const int64_t num_local_experts,
                                     const int64_t padded_m, const int64_t n,
                                     const int64_t k) {
  // Queried up front; only consumed by the error message at the bottom.
  int32_t const sm_version = get_sm_version_num();

#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) ||   \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \
    (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120)
  // A CUTLASS MoE kernel exists in this build, so the helper is available.
  get_cutlass_batched_moe_mm_data_caller(
      expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens,
      num_local_experts, padded_m, n, k);
  return;
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(false,
                              "No compiled get_cutlass_batched_moe_mm_data: no "
                              "cutlass_scaled_mm kernel "
                              "for CUDA device capability: ",
                              sm_version,
                              ". Required capability: 90, 100, or 120");
}
|
||||
|
||||
// Entry point for the CUTLASS scaled matrix multiply with an activation
// zero-point (azp) correction. Validates the operands and forwards them to
// the implementation compiled for the current device.
//
// Validation done here (all through TORCH_CHECK):
//   - a, b and c are 2-D with compatible sizes: c is (a.size(0), b.size(1))
//     and a.size(1) == b.size(0)
//   - a_scales has 1 or a.size(0) elements; b_scales has 1 or b.size(1)
//     elements; both are contiguous
//   - a and c have unit stride along dim 1, b has unit stride along dim 0
//   - c.stride(0) and b.stride(1) are multiples of 16
//   - bias (optional) is contiguous with b.size(1) elements and has the same
//     dtype as c
//   - azp (optional) is contiguous, int32, with a.size(0) elements
//   - azp_adj is contiguous, int32, with b.size(1) elements
//
// Dispatch is based on get_sm_version_num() and on which ENABLE_SCALED_MM_*
// macros were set to a non-zero value at build time. If no compiled branch
// matches, a not-implemented error is raised.
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           std::optional<torch::Tensor> const& azp,
                           std::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // bias, azp, azp_adj are expected to be 1d; only element count and
  // contiguity are verified here.
  // bias and azp_adj have n elements, azp has m elements
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // azp & bias types
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  // Make the device of `a` current for the remainder of this call.
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
  // NOTE(review): unlike cutlass_scaled_mm, this branch has no upper bound,
  // so every capability >= 90 is routed to the sm90 implementation. Confirm
  // this is intended for newer architectures.
  if (version_num >= 90) {
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  if (version_num >= 75) {
    // Turing
    cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
  // Anything older falls through to the descriptive error below, mirroring
  // the dispatch structure of cutlass_scaled_mm.
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm_azp for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
|
||||
@@ -439,90 +439,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" -> ()");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization, as well as bias
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor azp_adj,"
|
||||
" Tensor? azp, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
|
||||
|
||||
// Check if cutlass scaled_mm is supported for CUDA devices of the given
|
||||
// capability
|
||||
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
|
||||
|
||||
// Check if cutlass grouped gemm is supported for CUDA devices of the given
|
||||
// capability
|
||||
ops.def("cutlass_group_gemm_supported(int cuda_device_capability) -> bool");
|
||||
ops.impl("cutlass_group_gemm_supported", &cutlass_group_gemm_supported);
|
||||
|
||||
// CUTLASS w8a8 grouped GEMM
|
||||
ops.def(
|
||||
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
|
||||
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
|
||||
" Tensor problem_sizes, Tensor a_strides, "
|
||||
" Tensor b_strides, Tensor c_strides, bool per_act_token, "
|
||||
" bool per_out_ch) -> ()");
|
||||
ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm);
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM. It takes topk_ids as an input, and computes expert_offsets
|
||||
// (token start indices of each expert). In addition to this, it computes
|
||||
// problem sizes for each expert's multiplication used by the two mms called
|
||||
// from fused MoE operation, and arrays with permutations required to shuffle
|
||||
// and de-shuffle the input/output of the fused operation.
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
|
||||
" Tensor! input_permutation, "
|
||||
" Tensor! output_permutation, int num_experts, "
|
||||
" int n, int k, Tensor? blockscale_offsets, "
|
||||
" bool is_gated) -> ()");
|
||||
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
|
||||
|
||||
// compute per-expert problem sizes from expert_first_token_offset
|
||||
// produced by vLLM's moe_permute kernel
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_problem_sizes_from_expert_offsets("
|
||||
" Tensor expert_first_token_offset, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" int n, int k, bool swap_ab) -> ()");
|
||||
ops.impl("get_cutlass_moe_mm_problem_sizes_from_expert_offsets", torch::kCUDA,
|
||||
&get_cutlass_moe_mm_problem_sizes_from_expert_offsets);
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM in batched expert format. It takes expert_num_tokens
|
||||
// as an input, and computes expert_offsets (token start indices of each
|
||||
// expert). In addition to this, it computes problem sizes for each expert's
|
||||
// multiplication used by the two mms called from fused MoE operation.
|
||||
ops.def(
|
||||
"get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" Tensor expert_num_tokens, "
|
||||
" int num_local_experts, int padded_m, "
|
||||
" int n, int k) -> ()");
|
||||
ops.impl("get_cutlass_batched_moe_mm_data", torch::kCUDA,
|
||||
&get_cutlass_batched_moe_mm_data);
|
||||
|
||||
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
|
||||
"bool");
|
||||
ops.impl("cutlass_scaled_mm_supports_block_fp8",
|
||||
&cutlass_scaled_mm_supports_block_fp8);
|
||||
|
||||
// SM100 CUTLASS MLA decode
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
|
||||
|
||||
Reference in New Issue
Block a user