#pragma once #include #include #include #include #include #include // Stable ABI equivalent of TORCH_CHECK_NOT_IMPLEMENTED. #define STD_TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \ STD_TORCH_CHECK(cond, "NotImplementedError: ", __VA_ARGS__) // Utility to get the current CUDA stream for a given device using stable APIs. // Returns a cudaStream_t for use in kernel launches. inline cudaStream_t get_current_cuda_stream(int32_t device_index = -1) { void* stream_ptr = nullptr; TORCH_ERROR_CODE_CHECK( aoti_torch_get_current_cuda_stream(device_index, &stream_ptr)); return reinterpret_cast(stream_ptr); }