[Kernel] Fixup for CUTLASS kernels in CUDA graphs (#4954)
Pass the CUDA stream into the CUTLASS GEMMs, to avoid future issues with CUDA graphs
This commit is contained in:
committed by
GitHub
parent
c74c913bfb
commit
8674f9880e
@@ -1,6 +1,8 @@
|
||||
#include <stddef.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "cute/tensor.hpp"
|
||||
@@ -189,8 +191,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
cutlass::Status status = gemm_op(args, workspace.get());
|
||||
cutlass::Status status = gemm_op(args, workspace.get(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
@@ -178,7 +180,8 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor& out, torch::Tensor const& a,
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
TORCH_CHECK(workspace_size == 0);
|
||||
|
||||
cutlass::Status status = gemm_op.run(args);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
cutlass::Status status = gemm_op.run(args, stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Reference in New Issue
Block a user