Make various updates and fixes:
- Add support for legacy CUDA versions; now compatible with CUDA 12.3 and newer
- Add support for NVRTC compilation
- Other fixes and code refactoring
This commit is contained in:
@@ -5,10 +5,10 @@
|
||||
#include "jit/device_runtime.hpp"
|
||||
#include "utils/layout.hpp"
|
||||
|
||||
#include "jit_kernels/impls/smxx_layout.hpp"
|
||||
#include "jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
#include "jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
||||
#include "jit_kernels/impls/smxx_layout.hpp"
|
||||
|
||||
#ifndef TORCH_EXTENSION_NAME
|
||||
#define TORCH_EXTENSION_NAME deep_gemm_cpp
|
||||
@@ -17,8 +17,8 @@
|
||||
namespace deep_gemm {
|
||||
torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
|
||||
const int& mn, const int& k,
|
||||
const std::optional<int>& num_groups,
|
||||
const std::tuple<int, int, int>& recipe,
|
||||
const std::optional<int>& num_groups,
|
||||
const bool& is_sfa,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe);
|
||||
@@ -121,8 +121,8 @@ void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, recipe.value(), true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, std::nullopt, recipe.value(), false, disable_ue8m0_cast);
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch into different implements
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
@@ -133,7 +133,7 @@ void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types");
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,8 +208,8 @@ void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tens
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, recipe.value(), true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, recipe.value(), false, disable_ue8m0_cast);
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
@@ -223,7 +223,7 @@ void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tens
|
||||
sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types");
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,8 +271,8 @@ void fp8_m_grouped_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>&
|
||||
// Transform scaling factors
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, num_groups, recipe.value(), true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, recipe.value(), false, disable_ue8m0_cast);
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
@@ -286,7 +286,7 @@ void fp8_m_grouped_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>&
|
||||
sm100_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported kernel or scaling factor types");
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,18 +339,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "DeepGEMM C++ library";
|
||||
|
||||
// Runtime
|
||||
m.def("set_num_sms", [&](const int& new_num_sms) {
|
||||
device_runtime->set_num_sms(new_num_sms);
|
||||
});
|
||||
m.def("get_num_sms", [&]() {
|
||||
return device_runtime->get_num_sms();
|
||||
});
|
||||
m.def("set_num_sms", [&](const int& new_num_sms) {
|
||||
device_runtime->set_num_sms(new_num_sms);
|
||||
m.def("set_tc_util", [&](const int& new_tc_util) {
|
||||
device_runtime->set_tc_util(new_tc_util);
|
||||
});
|
||||
m.def("get_tc_util", [&]() {
|
||||
return device_runtime->get_tc_util();
|
||||
});
|
||||
|
||||
// JIT
|
||||
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_torch) {
|
||||
DG_HOST_ASSERT(get_env("DG_JIT_USE_NVRTC", 0) == 0 and "Currently only support NVCC");
|
||||
compiler = std::make_shared<NVCCCompiler>(library_root_path, cuda_home_path_by_torch);
|
||||
KernelRuntime::set_cuda_home(cuda_home_path_by_torch);
|
||||
Compiler::prepare_init(library_root_path, cuda_home_path_by_torch);
|
||||
KernelRuntime::prepare_init(cuda_home_path_by_torch);
|
||||
});
|
||||
|
||||
// Stable kernel APIs with automatic arch/layout dispatch
|
||||
@@ -391,7 +396,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
||||
py::arg("recipe") = std::make_tuple(1, 1, 128),
|
||||
py::arg("compiled_dims") = "mn");
|
||||
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout);
|
||||
|
||||
// Layout kernels
|
||||
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
|
||||
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
|
||||
py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false,
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
|
||||
// Raw kernels or functions
|
||||
m.def("get_tma_aligned_size", &get_tma_aligned_size);
|
||||
|
||||
Reference in New Issue
Block a user