[ROCm] [Bugfix] Fix fused_qknorm_rope_kernel ROCm compatibility (#28500)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -67,9 +67,9 @@ struct _typeConvert<c10::Half> {
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) || defined(USE_ROCM)
|
||||
// CUDA_ARCH < 800 does not have BF16 support
|
||||
// TODO: Add in ROCm support once public headers handle bf16 maturely
|
||||
// ROCm 7.0+ supports bfloat16
|
||||
template <>
|
||||
struct _typeConvert<c10::BFloat16> {
|
||||
static constexpr bool exists = true;
|
||||
@@ -89,7 +89,8 @@ struct _typeConvert<c10::BFloat16> {
|
||||
return __float22bfloat162_rn(x);
|
||||
}
|
||||
};
|
||||
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
|
||||
// defined(USE_ROCM)
|
||||
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
|
||||
// 12000))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user