diff --git a/dsv4/kernels/attention/fmha_sm100.cuh b/dsv4/kernels/attention/fmha_sm100.cuh index 3cd69cf0..8a01b707 100644 --- a/dsv4/kernels/attention/fmha_sm100.cuh +++ b/dsv4/kernels/attention/fmha_sm100.cuh @@ -30,16 +30,8 @@ #pragma once #include -#include - -// NOTE: cuda_bf16.h has a C++17 compatibility bug on CUDA 13.2. -// We define a minimal BF16 type using the __bf16 built-in (CUDA 13+). -// __bf16 is a built-in storage type; we wrap it for type safety. -#if defined(__CUDA_ARCH__) -// Device code: __nv_bfloat16 IS available via the built-in -#undef __BF16_COMPAT #include -#endif +#include // CUTLASS C++ includes (CUDA device code only) #if defined(__CUDA_ARCH__) diff --git a/tests/unit/test_fmha_sm100.py b/tests/unit/test_fmha_sm100.py index 3c598810..9d1623c5 100644 --- a/tests/unit/test_fmha_sm100.py +++ b/tests/unit/test_fmha_sm100.py @@ -42,6 +42,7 @@ nvcc_cmd = [ "--x", "cu", "-o", "/tmp/fmha_sm100_test.o", "--ptxas-options=-v", + "--expt-relaxed-constexpr", ] print(f"nvcc command: {' '.join(nvcc_cmd)}")