[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)
This commit is contained in:
@@ -117,10 +117,10 @@ struct cutlass_2x_gemm {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
@@ -136,9 +136,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
||||
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB const *>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB const *>(b.data_ptr());
|
||||
auto c_ptr = static_cast<ElementD *>(out.data_ptr());
|
||||
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
|
||||
auto a_scales_ptr = a_scales.data_ptr<float>();
|
||||
auto b_scales_ptr = b_scales.data_ptr<float>();
|
||||
@@ -196,10 +196,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
||||
|
||||
} // namespace
|
||||
|
||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
@@ -223,10 +223,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a,
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
@@ -250,10 +250,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a,
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
@@ -120,10 +120,10 @@ struct cutlass_3x_gemm {
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
@@ -146,12 +146,12 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB *>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB *>(b.data_ptr());
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||
b_stride};
|
||||
|
||||
auto c_ptr = static_cast<ElementD *>(out.data_ptr());
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
@@ -183,10 +183,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor &out, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
|
||||
@@ -2,29 +2,29 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor &c, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales);
|
||||
void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor &c, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales);
|
||||
void cutlass_scaled_mm_dq_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor &c, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales);
|
||||
void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor &c, torch::Tensor const &a,
|
||||
torch::Tensor const &b,
|
||||
torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales);
|
||||
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
|
||||
torch::Tensor const &b, torch::Tensor const &a_scales,
|
||||
torch::Tensor const &b_scales) {
|
||||
void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
int32_t major_capability;
|
||||
int32_t minor_capability;
|
||||
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
|
||||
@@ -36,14 +36,15 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
b.size(1) == c.size(1));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 && b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
|
||||
Reference in New Issue
Block a user