[CI/Build] Suppress divide-by-zero and missing return statement warnings (#7001)
This commit is contained in:
committed by
GitHub
parent
8571ac4672
commit
6e4852ce28
@@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
|
||||
|
||||
return result;
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
} // namespace awq
|
||||
|
||||
@@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
__NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
@@ -508,7 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
return {}; // Squash missing return statement warning
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
@@ -521,7 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
return {}; // Squash missing return statement warning
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
|
||||
@@ -1130,12 +1130,12 @@ __global__ void Marlin(
|
||||
};
|
||||
|
||||
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
|
||||
if constexpr (has_zp) {
|
||||
// This code does not handle group_blocks == 0,
|
||||
// which signifies act_order.
|
||||
// has_zp implies AWQ, which doesn't have act_order,
|
||||
static_assert(group_blocks != 0);
|
||||
// This code does not handle group_blocks == 0,
|
||||
// which signifies act_order.
|
||||
// has_zp implies AWQ, which doesn't have act_order,
|
||||
static_assert(!has_zp || group_blocks != 0);
|
||||
|
||||
if constexpr (has_zp) {
|
||||
int pipe = full_pipe % stages;
|
||||
|
||||
if constexpr (group_blocks == -1) {
|
||||
@@ -1161,7 +1161,13 @@ __global__ void Marlin(
|
||||
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
||||
|
||||
int k_blocks = cur_k / 16;
|
||||
int cur_group_id = k_blocks / group_blocks;
|
||||
int cur_group_id = 0;
|
||||
|
||||
// Suppress bogus and persistent divide-by-zero warning
|
||||
#pragma nv_diagnostic push
|
||||
#pragma nv_diag_suppress divide_by_zero
|
||||
cur_group_id = k_blocks / group_blocks;
|
||||
#pragma nv_diagnostic pop
|
||||
|
||||
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user