[CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722)

This commit is contained in:
Michael Goin
2024-05-22 03:18:41 -04:00
committed by GitHub
parent 9b9a10d6cb
commit 5f6d10c14c
64 changed files with 6398 additions and 6790 deletions

View File

@@ -22,51 +22,56 @@ namespace marlin_24 {
// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1,
const FragA &frag_b, FragC &frag_c, FragM &frag_m,
__device__ inline void mma_sp(const FragB& a_frag0, const FragB& a_frag1,
const FragA& frag_b, FragC& frag_c, FragM& frag_m,
const int psel) {
const uint32_t *a0 = reinterpret_cast<const uint32_t *>(&a_frag0);
const uint32_t *a1 = reinterpret_cast<const uint32_t *>(&a_frag1);
const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
const uint32_t *e = reinterpret_cast<const uint32_t *>(&frag_m);
float *c = reinterpret_cast<float *>(&frag_c);
const uint32_t* a0 = reinterpret_cast<const uint32_t*>(&a_frag0);
const uint32_t* a1 = reinterpret_cast<const uint32_t*>(&a_frag1);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* e = reinterpret_cast<const uint32_t*>(&frag_m);
float* c = reinterpret_cast<float*>(&frag_c);
if (psel == 0) {
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
"f"(c[2]), "f"(c[3]), "r"(e[0]));
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
"f"(c[6]), "f"(c[7]), "r"(e[0]));
asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
"r"(e[0]));
asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x0;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
"r"(e[0]));
} else {
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
"r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
"f"(c[2]), "f"(c[3]), "r"(e[0]));
asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
"r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
"f"(c[6]), "f"(c[7]), "r"(e[0]));
asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]), "r"(b[2]),
"r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]),
"r"(e[0]));
asm volatile(
"mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
"{%12,%13,%14,%15}, %16, 0x1;\n"
: "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
: "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]), "r"(b[3]),
"r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]), "f"(c[6]), "f"(c[7]),
"r"(e[0]));
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut> __device__ inline int lop3(int a, int b, int c) {
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
@@ -120,11 +125,11 @@ __device__ inline FragB dequant_4bit(int q) {
const int ADD = 0xd480d480;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
*reinterpret_cast<const half2 *>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
*reinterpret_cast<const half2 *>(&MUL),
*reinterpret_cast<const half2 *>(&ADD));
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
@@ -143,24 +148,24 @@ __device__ inline FragB dequant_8bit(int q) {
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
__device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3,
FragS &s0, float *c4, float *c5, float *c6,
float *c7, FragS &s1) {
__device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
FragS& s0, float* c4, float* c5, float* c6,
float* c7, FragS& s1) {
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));
*c1 = __fmul_rn(*c1, __half2float(s0[0].y));
*c2 = __fmul_rn(*c2, __half2float(s0[1].x));
@@ -172,4 +177,4 @@ __device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3,
*c7 = __fmul_rn(*c7, __half2float(s1[1].y));
}
} // namespace marlin_24
} // namespace marlin_24