diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu index 8ef4fabe..c8a1620b 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/cutlass_nvfp4_gemm.cu @@ -43,7 +43,7 @@ #include "cutlass/kernel_hardware_info.hpp" #include "cutlass/detail/sm100_blockscaled_layout.hpp" -#include "cute/numeric/float8.hpp" +#include "cutlass/float_subbyte.h" #include "cute/layout.hpp" using namespace cute; diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp index 3bd0a8cc..93c31d6c 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp @@ -27,7 +27,7 @@ #include "cutlass/kernel_hardware_info.hpp" #include "cutlass/detail/sm100_blockscaled_layout.hpp" -#include "cute/numeric/float8.hpp" +#include "cutlass/float_subbyte.h" #include "cute/layout.hpp" using namespace cute; diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/setup.py b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/setup.py index 90f3802a..08de1df0 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/setup.py +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/setup.py @@ -1,18 +1,17 @@ -""" -Setup script for CUTLASS NVFP4 block-scaled GEMM PyTorch extension. -""" +"""Setup script for CUTLASS NVFP4 block-scaled GEMM PyTorch extension.""" import os from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension -# CUTLASS include directory +# CUTLASS include directory — prefer the latest from GitHub CUTLASS_INCLUDE_DIR = os.environ.get( "CUTLASS_INCLUDE_DIR", - "/usr/local/lib/python3.12/dist-packages/tilelang/3rdparty/cutlass/include" + "/root/cutlass/include" ) if not os.path.exists(os.path.join(CUTLASS_INCLUDE_DIR, "cutlass", "cutlass.h")): for alt in [ + "/root/cutlass/include", "/usr/local/lib/python3.12/dist-packages/tilelang/3rdparty/cutlass/include", "/usr/local/include/cutlass", "/opt/cutlass/include",