[Kernel] Add per-tensor and per-token AZP epilogues (#5941)
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -73,19 +73,63 @@ struct enable_sm89_to_sm90 : Kernel {
|
||||
};
|
||||
|
||||
/*
|
||||
* This class provides the common ScaleA and ScaleB descriptors for the
|
||||
* ScaledEpilogue and ScaledEpilogueBias classes.
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
template <typename T>
|
||||
using ColOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
template <typename T>
|
||||
using RowOrScalarLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrZeroLoad =
|
||||
cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
|
||||
OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(torch::Tensor const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = static_cast<T*>(tensor.data_ptr());
|
||||
if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
|
||||
return Arguments{data_ptr, tensor.numel() != 1};
|
||||
} else {
|
||||
// it would technically work but no use case as data_ptr is never nullptr
|
||||
static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
}
|
||||
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
@@ -110,8 +154,8 @@ struct ScaledEpilogue
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::ScaleA;
|
||||
using ScaleB = typename SUPER::ScaleB;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
@@ -131,28 +175,32 @@ struct ScaledEpilogue
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using ScaleAArgs = typename ScaleA::Arguments;
|
||||
using ScaleBArgs = typename ScaleB::Arguments;
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
|
||||
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
|
||||
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
|
||||
return evt_compute_args;
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
|
||||
* This bias can also be used in the per-tensor azp case, where the activation
|
||||
* zero point (azp) is used to compute an azp correction term,
|
||||
* which is folded into the bias.
|
||||
*
|
||||
* The bias tensor must be per-output channel.
|
||||
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBias
|
||||
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
protected:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::ScaleA;
|
||||
using ScaleB = typename SUPER::ScaleB;
|
||||
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowLoad<ElementD>;
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
@@ -164,30 +212,163 @@ struct ScaledEpilogueBias
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||
EVTCompute0, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
using ScaleAArgs = typename ScaleA::Arguments;
|
||||
using ScaleBArgs = typename ScaleB::Arguments;
|
||||
using BiasArgs = typename Bias::Arguments;
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
/*
|
||||
* This epilogue directly supports per-tensor azp in int32 form.
|
||||
* As opposed to the per-token epilogue below, this epilogue only has an azp_adj
|
||||
* term, which should already be multiplied with the scalar azp.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||
|
||||
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
|
||||
bias_args};
|
||||
return evt_compute_args;
|
||||
// This is the full AZP term, azp * J @ B, shape (1,n)
|
||||
using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute float(accum - azp_adj), both operands are int32_t
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This epilogue supports per-token azp by computing and applying
|
||||
* the correction term using a rank-1 update. If the term were materialized,
|
||||
* it would require O(m*n) space, and this way it only requires O(m+n) space.
|
||||
* The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
|
||||
* point for each row of A.
|
||||
* The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
|
||||
|
||||
// Per-token azp term, shape (m,1)
|
||||
using Azp = typename SUPER::template ColLoad<int32_t>;
|
||||
|
||||
// This is the AZP adjustment term, J @ B, shape (1,n)
|
||||
using AzpAdj = typename SUPER::template RowLoad<int32_t>;
|
||||
|
||||
// Compute azp * azp_adj
|
||||
using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, int32_t, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAzp =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
|
||||
|
||||
// Compute float(accum - azp*azp_adj), all operands are int32_t
|
||||
using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::minus, float, int32_t,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeAcc =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTComputeScaleB =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
|
||||
EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
|
||||
EVTComputeScaleB, Bias>;
|
||||
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user