diff --git a/dsv4/kernels/attention/fmha_sm100.cuh b/dsv4/kernels/attention/fmha_sm100.cuh
index 8a01b707..f64cad8e 100644
--- a/dsv4/kernels/attention/fmha_sm100.cuh
+++ b/dsv4/kernels/attention/fmha_sm100.cuh
@@ -30,9 +30,15 @@
 #pragma once
 
 #include
-#include <cuda_bf16.h>
 #include
 
+// __nv_bfloat16 is a built-in type on CUDA 13+ -- NOTE(review): confirm; the type is normally declared by cuda_bf16.h, and __CUDA_ARCH__ is undefined in nvcc's host pass
+// cuda_bf16.h may have C++17 compatibility issues with some CUDA versions;
+// only include it in device code where it's guaranteed to work.
+#if defined(__CUDA_ARCH__)
+#include <cuda_bf16.h>
+#endif
+
 // CUTLASS C++ includes (CUDA device code only)
 #if defined(__CUDA_ARCH__)
 #include