[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)

This commit is contained in:
Michael Goin
2024-05-22 03:18:41 -04:00
committed by GitHub
parent 9b9a10d6cb
commit 5f6d10c14c
64 changed files with 6398 additions and 6790 deletions

View File

@@ -21,41 +21,44 @@
namespace marlin_24 {
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void *smem_ptr,
const void *glob_ptr,
__device__ inline void cp_async4_pred_zfill(void* smem_ptr,
const void* glob_ptr,
bool pred = true,
const bool zfill = false) {
const int BYTES = 16;
int src_in_bytes = (zfill ? 0 : BYTES);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
}
__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Asynchronous global->shared copy
__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
// Async copy fence.
@@ -64,22 +67,23 @@ __device__ inline void cp_async_fence() {
}
// Wait until at most `n` async copy stages are still pending.
template <int n> __device__ inline void cp_async_wait() {
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
}
__device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) {
uint32_t *a = reinterpret_cast<uint32_t *>(&frag_m);
__device__ inline void ldsm4_m(FragM& frag_m, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_m);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
: "=r"(a[0]), "=r"(a[1])
@@ -88,8 +92,8 @@ __device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) {
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) {
uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
__device__ inline void ldsm4_t(FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
@@ -98,7 +102,7 @@ __device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) {
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int *lock, int count) {
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
@@ -113,7 +117,7 @@ __device__ inline void barrier_acquire(int *lock, int count) {
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int *lock, bool reset = false) {
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
@@ -129,4 +133,4 @@ __device__ inline void barrier_release(int *lock, bool reset = false) {
: "l"(lock), "r"(val));
}
}
} // namespace marlin_24
} // namespace marlin_24