23 lines
816 B
C
23 lines
816 B
C
#pragma once
|
|
|
|
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
|
#include <torch/csrc/stable/accelerator.h>
|
|
#include <torch/csrc/stable/ops.h>
|
|
#include <torch/csrc/stable/tensor.h>
|
|
#include <torch/headeronly/util/shim_utils.h>
|
|
|
|
#include <cuda_runtime.h>
|
|
|
|
// Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED.
//
// Evaluates `cond` via STD_TORCH_CHECK. On failure, the reported message is
// prefixed with "NotImplementedError: ", followed by any extra arguments, which
// are forwarded verbatim as message parts. Note that only the message carries
// the NotImplementedError marker: the thrown exception type is whatever
// STD_TORCH_CHECK throws.
//
// `, ##__VA_ARGS__` is the widely supported comma-elision extension (GCC,
// Clang, MSVC). It drops the preceding comma when no message arguments are
// passed, so `STD_TORCH_CHECK_NOT_IMPLEMENTED(cond)` expands to
// `STD_TORCH_CHECK(cond, "NotImplementedError: ")`. The alternative would leave
// a dangling empty argument in the forwarded argument list.
#define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
  STD_TORCH_CHECK(cond, "NotImplementedError: ", ##__VA_ARGS__)
|
|
|
|
// Utility to get the current CUDA stream for a given device using stable APIs.
|
|
// Returns a cudaStream_t for use in kernel launches.
|
|
inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) {
|
|
void* stream_ptr = nullptr;
|
|
TORCH_ERROR_CODE_CHECK(
|
|
aoti_torch_get_current_cuda_stream(device_index, &stream_ptr));
|
|
return reinterpret_cast<cudaStream_t>(stream_ptr);
|
|
}
|