[3/n] Migrate cutlass/scaled_mm_entry.cu torch stable ABI (#37221)
Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
@@ -3,6 +3,14 @@
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
|
||||
|
||||
// This header is shared by both _C (unstable ABI) and _C_stable_libtorch
|
||||
// (stable ABI) targets. When compiled under the stable ABI target,
|
||||
// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we
|
||||
// use torch::stable::Tensor instead.
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
#include <torch/csrc/stable/tensor.h>
|
||||
#endif
|
||||
|
||||
/*
|
||||
This file defines custom epilogues for fusing channel scales, token scales,
|
||||
bias, and activation zero-points onto a GEMM operation using the
|
||||
@@ -15,6 +23,12 @@
|
||||
|
||||
namespace vllm::c3x {
|
||||
|
||||
#ifdef TORCH_TARGET_VERSION
|
||||
using TensorType = torch::stable::Tensor;
|
||||
#else
|
||||
using TensorType = torch::Tensor;
|
||||
#endif
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename T>
|
||||
@@ -84,7 +98,7 @@ struct ScaledEpilogueBase {
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
static auto args_from_tensor(TensorType const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
@@ -100,7 +114,7 @@ struct ScaledEpilogueBase {
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
||||
static auto args_from_tensor(std::optional<TensorType> const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||
@@ -158,8 +172,8 @@ struct ScaledEpilogue
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
@@ -203,9 +217,9 @@ struct ScaledEpilogueBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& azp_adj,
|
||||
std::optional<TensorType> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken
|
||||
EVTComputeScaleB, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
static ArgumentType prepare_args(TensorType const& a_scales,
|
||||
TensorType const& b_scales,
|
||||
TensorType const& azp_adj,
|
||||
TensorType const& azp,
|
||||
std::optional<TensorType> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
Reference in New Issue
Block a user