diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp index 6d8b53c7..c7549088 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp @@ -20,7 +20,7 @@ torch::Tensor cutlass_nvfp4_gemm_forward( torch::Tensor SFB, int64_t M, int64_t N, int64_t K ) { - auto D = torch::empty({M, N}, torch::dtype(at::kBF16).device(A_packed.device())); + auto D = torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(A_packed.device())); auto stream = c10::cuda::getCurrentCUDAStream(); cudaStream_t cuda_stream = stream.stream();