[Kernel] Fixup for CUTLASS kernels in CUDA graphs (#4954)

Pass the CUDA stream into the CUTLASS GEMMs, to avoid future issues with CUDA graphs
This commit is contained in:
Tyler Michael Smith
2024-05-22 10:10:43 -04:00
committed by GitHub
parent c74c913bfb
commit 8674f9880e
3 changed files with 50 additions and 2 deletions

View File

@@ -1,6 +1,8 @@
#include <stddef.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
@@ -189,8 +191,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.get());
cutlass::Status status = gemm_op(args, workspace.get(), stream);
CUTLASS_CHECK(status);
}

View File

@@ -1,5 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <iostream>
#include <sstream>
#include <vector>
@@ -178,7 +180,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
size_t workspace_size = gemm_op.get_workspace_size(args);
TORCH_CHECK(workspace_size == 0);
cutlass::Status status = gemm_op.run(args);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
cutlass::Status status = gemm_op.run(args, stream);
CUTLASS_CHECK(status);
}
} // namespace