[Bugfix]fix output Nan/Inf in marlin if dtype=float16 (#33972)

Signed-off-by: IriKa Qiu <qiujie.jq@gmail.com>
This commit is contained in:
IriKa
2026-03-28 07:36:08 +08:00
committed by GitHub
parent b69bf2f0b1
commit 148a5c1226
8 changed files with 83 additions and 55 deletions

View File

@@ -189,10 +189,7 @@ __device__ __forceinline__ void cp_async_wait<0>() {
}
__device__ __forceinline__ float clip(float v, float mmin, float mmax) {
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
return fminf(mmax, fmaxf(v, mmin));
#else
#endif
}
__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,