[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel rocm compatibility (#28500)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian
2025-11-12 05:01:14 -08:00
committed by GitHub
parent c5f10cc139
commit edb59a9470
6 changed files with 37 additions and 38 deletions

View File

@@ -67,9 +67,9 @@ struct _typeConvert<c10::Half> {
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
// ROCm 7.0+ supports bfloat16
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
@@ -89,7 +89,8 @@ struct _typeConvert<c10::BFloat16> {
return __float22bfloat162_rn(x);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
// defined(USE_ROCM)
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))