Fix: torch::kBFloat16

This commit is contained in:
2026-05-14 10:21:10 +00:00
parent 3f62e49e6e
commit a272bc49b0

View File

@@ -20,7 +20,7 @@ torch::Tensor cutlass_nvfp4_gemm_forward(
torch::Tensor SFB,
int64_t M, int64_t N, int64_t K
) {
auto D = torch::empty({M, N}, torch::dtype(at::kBF16).device(A_packed.device()));
auto D = torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(A_packed.device()));
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t cuda_stream = stream.stream();