fix: guard cuda_bf16.h with __CUDA_ARCH__

This commit is contained in:
2026-05-28 05:11:11 +00:00
parent 5e389b5ed9
commit 8783a25deb

View File

@@ -30,9 +30,15 @@
#pragma once
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cstdint>
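// NOTE(review): the original note states that __nv_bfloat16 is a built-in type
// on CUDA 13+ — confirm against the toolkit headers before relying on it.
// cuda_bf16.h may have C++17 compatibility issues with some CUDA versions, so
// the include below is restricted to compilation passes where __CUDA_ARCH__ is
// defined. __CUDA_ARCH__ is defined only during nvcc's device compilation pass,
// so this guarded include is skipped in the host pass.
// NOTE(review): an unguarded #include <cuda_bf16.h> also appears above —
// confirm whether it was meant to be removed; if it stays, this guard is redundant.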
#if defined(__CUDA_ARCH__)
#include <cuda_bf16.h>
#endif
// CUTLASS C++ includes (CUDA device code only)
#if defined(__CUDA_ARCH__)
#include <cutlass/cutlass.h>