fix: guard cuda_bf16.h with __CUDA_ARCH__
This commit is contained in:
@@ -30,9 +30,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cstdint>
|
||||
|
||||
// __nv_bfloat16 is a built-in type on CUDA 13+
|
||||
// cuda_bf16.h may have C++17 compatibility issues with some CUDA versions;
|
||||
// only include it in the device compilation pass (__CUDA_ARCH__ defined), where it's guaranteed to work.
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
// CUTLASS C++ includes (CUDA device code only)
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#include <cutlass/cutlass.h>
|
||||
|
||||
Reference in New Issue
Block a user