Make various updates and fixes: (#164)
- Add BF16 support for SM90 and SM100
- Refactor Python APIs
- Other fixes and code refactoring
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
@@ -10,7 +11,7 @@ class DGException final : public std::exception {
|
||||
|
||||
public:
|
||||
explicit DGException(const char *name, const char* file, const int line, const std::string& error) {
|
||||
message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'";
|
||||
message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error;
|
||||
}
|
||||
|
||||
const char *what() const noexcept override {
|
||||
@@ -50,7 +51,11 @@ do { \
|
||||
do { \
|
||||
const auto& e = (cmd); \
|
||||
if (e != CUDA_SUCCESS) { \
|
||||
throw DGException("CUDA driver", __FILE__, __LINE__, ""); \
|
||||
std::stringstream ss; \
|
||||
const char *name, *info; \
|
||||
cuGetErrorName(e, &name), cuGetErrorString(e, &info); \
|
||||
ss << static_cast<int>(e) << " (" << name << ", " << info << ")"; \
|
||||
throw DGException("CUDA driver", __FILE__, __LINE__, ss.str()); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
@@ -60,7 +65,9 @@ do { \
|
||||
do { \
|
||||
const auto& e = (cmd); \
|
||||
if (e != cudaSuccess) { \
|
||||
throw DGException("CUDA runtime", __FILE__, __LINE__, std::to_string(static_cast<int>(e))); \
|
||||
std::stringstream ss; \
|
||||
ss << static_cast<int>(e) << " (" << cudaGetErrorName(e) << ", " << cudaGetErrorString(e) << ")"; \
|
||||
throw DGException("CUDA runtime", __FILE__, __LINE__, ss.str()); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user