[4/n] Migrate FP4/W4A8 CUTLASS kernels to torch stable ABI (#37503)

Signed-off-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
This commit is contained in:
mikaylagawarecki
2026-03-31 13:21:13 -04:00
committed by GitHub
parent 0dd25a44ea
commit 7c080dd3c5
27 changed files with 1205 additions and 1016 deletions

View File

@@ -1,6 +1,21 @@
#pragma once
#include <torch/all.h>
// This header is shared between _C (unstable ABI, used by machete) and
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
// is defined only for the stable target, so we switch includes and types
// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor.
#ifdef TORCH_TARGET_VERSION
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/BFloat16.h>
#include <torch/headeronly/util/Half.h>
#include <torch/headeronly/util/shim_utils.h> // for STD_TORCH_CHECK
using TorchTensor = torch::stable::Tensor;
#define TORCH_UTILS_CHECK STD_TORCH_CHECK
#else
#include <torch/all.h>
using TorchTensor = torch::Tensor;
#define TORCH_UTILS_CHECK TORCH_CHECK
#endif
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
@@ -55,35 +70,35 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(torch::Tensor const& tensor,
static inline auto make_cute_layout(TorchTensor const& tensor,
std::string_view name = "tensor") {
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(
Stride{}, [&](auto const& stride_ele, auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele,
auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.stride(idx);
}
}
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
return tensor.stride(idx);
}
});
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
if (idx < tensor.dim())
@@ -97,7 +112,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
template <typename Stride>
static inline auto maybe_make_cute_layout(
std::optional<torch::Tensor> const& tensor,
std::optional<TorchTensor> const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor));
@@ -121,12 +136,12 @@ template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <>
struct equivalent_cutlass_type<c10::Half> {
struct equivalent_cutlass_type<torch::headeronly::Half> {
using type = cutlass::half_t;
};
template <>
struct equivalent_cutlass_type<c10::BFloat16> {
struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
using type = cutlass::bfloat16_t;
};
@@ -134,8 +149,8 @@ struct equivalent_cutlass_type<c10::BFloat16> {
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
// Return a `torch::headeronly::CppTypeToScalarType<T>` compatible type, i.e.
// get the C++ type equivalent to T, e.g.: `cutlass::half_t -> Half`
template <typename T>
struct equivalent_scalar_type {
using type = T;
@@ -146,15 +161,15 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <>
struct equivalent_scalar_type<cutlass::half_t> {
using type = c10::Half;
using type = torch::headeronly::Half;
};
template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
using type = c10::BFloat16;
using type = torch::headeronly::BFloat16;
};
// get equivalent c10::ScalarType tag from compile time type
// get equivalent torch::headeronly::ScalarType tag from compile time type
template <typename T>
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v =
torch::headeronly::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;