Fix: use torch::kBFloat16 instead of at::kBF16 for the output tensor dtype
This commit is contained in:
@@ -20,7 +20,7 @@ torch::Tensor cutlass_nvfp4_gemm_forward(
|
||||
torch::Tensor SFB,
|
||||
int64_t M, int64_t N, int64_t K
|
||||
) {
|
||||
- auto D = torch::empty({M, N}, torch::dtype(at::kBF16).device(A_packed.device()));
|
||||
+ auto D = torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(A_packed.device()));
|
||||
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
cudaStream_t cuda_stream = stream.stream();
|
||||
|
||||
Reference in New Issue
Block a user