Fix PyTorch API: use c10::cuda and at::kBF16
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
/** PyTorch binding for CUTLASS NVFP4 block-scaled GEMM */
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
extern "C" int cutlass_nvfp4_gemm_run(
|
||||
const void* A_ptr, const void* SFA_ptr,
|
||||
@@ -13,19 +14,16 @@ extern "C" int cutlass_nvfp4_gemm_run(
|
||||
);
|
||||
|
||||
torch::Tensor cutlass_nvfp4_gemm_forward(
|
||||
torch::Tensor A_packed, // (M, K_half) int8 packed E2M1
|
||||
torch::Tensor SFA, // UE4M3 block scales for A
|
||||
torch::Tensor B_packed, // (N, K_half) int8 packed E2M1
|
||||
torch::Tensor SFB, // UE4M3 block scales for B
|
||||
int64_t M, int64_t N, int64_t K // K in FP4 elements (actual K = K * 2 since 2 FP4 per byte)
|
||||
torch::Tensor A_packed,
|
||||
torch::Tensor SFA,
|
||||
torch::Tensor B_packed,
|
||||
torch::Tensor SFB,
|
||||
int64_t M, int64_t N, int64_t K
|
||||
) {
|
||||
// Runtime check for SM100
|
||||
auto props = at::cuda::getCurrentDeviceProperties();
|
||||
TORCH_CHECK(props->major >= 10, "CUTLASS NVFP4 GEMM requires SM100 (Blackwell) architecture, got SM", props->major, props->minor);
|
||||
auto D = torch::empty({M, N}, torch::dtype(at::kBF16).device(A_packed.device()));
|
||||
|
||||
auto D = torch::empty({M, N}, torch::dtype(torch::kBF16).device(A_packed.device()));
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
cudaStream_t cuda_stream = stream.stream();
|
||||
|
||||
int rc = cutlass_nvfp4_gemm_run(
|
||||
A_packed.data_ptr(), SFA.data_ptr(),
|
||||
@@ -33,7 +31,7 @@ torch::Tensor cutlass_nvfp4_gemm_forward(
|
||||
D.data_ptr(),
|
||||
static_cast<int>(M), static_cast<int>(N), static_cast<int>(K),
|
||||
1.0f, 0.0f,
|
||||
stream
|
||||
cuda_stream
|
||||
);
|
||||
|
||||
TORCH_CHECK(rc == 0, "CUTLASS NVFP4 GEMM failed with error code ", rc);
|
||||
|
||||
Reference in New Issue
Block a user