From 3f62e49e6e02e2f2fab7fcd4be16a67e95d5c5f6 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 10:20:00 +0000 Subject: [PATCH] Fix PyTorch API: use c10::cuda and at::kBF16 --- .../cutlass_nvfp4_gemm/pytorch_binding.cpp | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp index c9b3f4ba..6d8b53c7 100644 --- a/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp +++ b/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm/pytorch_binding.cpp @@ -1,7 +1,8 @@ /** PyTorch binding for CUTLASS NVFP4 block-scaled GEMM */ #include -#include +#include +#include extern "C" int cutlass_nvfp4_gemm_run( const void* A_ptr, const void* SFA_ptr, @@ -13,19 +14,16 @@ extern "C" int cutlass_nvfp4_gemm_run( ); torch::Tensor cutlass_nvfp4_gemm_forward( - torch::Tensor A_packed, // (M, K_half) int8 packed E2M1 - torch::Tensor SFA, // UE4M3 block scales for A - torch::Tensor B_packed, // (N, K_half) int8 packed E2M1 - torch::Tensor SFB, // UE4M3 block scales for B - int64_t M, int64_t N, int64_t K // K in FP4 elements (actual K = K * 2 since 2 FP4 per byte) + torch::Tensor A_packed, + torch::Tensor SFA, + torch::Tensor B_packed, + torch::Tensor SFB, + int64_t M, int64_t N, int64_t K ) { - // Runtime check for SM100 - auto props = at::cuda::getCurrentDeviceProperties(); - TORCH_CHECK(props->major >= 10, "CUTLASS NVFP4 GEMM requires SM100 (Blackwell) architecture, got SM", props->major, props->minor); + auto D = torch::empty({M, N}, torch::dtype(at::kBF16).device(A_packed.device())); - auto D = torch::empty({M, N}, torch::dtype(torch::kBF16).device(A_packed.device())); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto stream = c10::cuda::getCurrentCUDAStream(); + cudaStream_t cuda_stream = stream.stream(); int rc = cutlass_nvfp4_gemm_run( A_packed.data_ptr(), SFA.data_ptr(), @@ -33,7 +31,7 @@ torch::Tensor cutlass_nvfp4_gemm_forward( D.data_ptr(), static_cast(M), static_cast(N), static_cast(K), 1.0f, 0.0f, - stream + cuda_stream ); TORCH_CHECK(rc == 0, "CUTLASS NVFP4 GEMM failed with error code ", rc);