[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)

This commit is contained in:
Michael Goin
2024-05-22 03:18:41 -04:00
committed by GitHub
parent 9b9a10d6cb
commit 5f6d10c14c
64 changed files with 6398 additions and 6790 deletions

View File

@@ -117,10 +117,10 @@ struct cutlass_2x_gemm {
};
template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales) {
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
@@ -136,9 +136,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
auto a_ptr = static_cast<ElementAB const *>(a.data_ptr());
auto b_ptr = static_cast<ElementAB const *>(b.data_ptr());
auto c_ptr = static_cast<ElementD *>(out.data_ptr());
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
auto a_scales_ptr = a_scales.data_ptr<float>();
auto b_scales_ptr = b_scales.data_ptr<float>();
@@ -196,10 +196,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
} // namespace
void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales) {
void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@@ -223,10 +223,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a,
}
}
void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales) {
void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
@@ -250,10 +250,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a,
}
}
void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales) {
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

View File

@@ -120,10 +120,10 @@ struct cutlass_3x_gemm {
};
template <typename Gemm>
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales) {
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
@@ -146,12 +146,12 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
using GemmKernel = typename Gemm::GemmKernel;
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
auto a_ptr = static_cast<ElementAB *>(a.data_ptr());
auto b_ptr = static_cast<ElementAB *>(b.data_ptr());
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
b_stride};
auto c_ptr = static_cast<ElementD *>(out.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride};
@@ -183,10 +183,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
}
} // namespace
void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales) {
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

View File

@@ -2,29 +2,29 @@
#include <cuda_runtime.h>
#include <torch/extension.h>
void cutlass_scaled_mm_dq_sm75(torch::Tensor &c, torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales);
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm80(torch::Tensor &c, torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales);
void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm89(torch::Tensor &c, torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales);
void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq_sm90(torch::Tensor &c, torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales);
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
torch::Tensor const &b, torch::Tensor const &a_scales,
torch::Tensor const &b_scales) {
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
int32_t major_capability;
int32_t minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
@@ -36,14 +36,15 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
b.size(1) == c.size(1));
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));