Make various updates and fixes:

- Add support for legacy CUDA versions; now compatible with CUDA 12.3 and newer
- Add support for NVRTC compilation
- Other fixes and code refactoring
This commit is contained in:
Ray Wang
2025-08-02 19:52:22 -07:00
parent aff9da0aba
commit d9c363f86f
36 changed files with 592 additions and 362 deletions

View File

@@ -5,10 +5,10 @@
#include "jit/device_runtime.hpp"
#include "utils/layout.hpp"
#include "jit_kernels/impls/smxx_layout.hpp"
#include "jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
#include "jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
#include "jit_kernels/impls/smxx_layout.hpp"
#ifndef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME deep_gemm_cpp
@@ -17,8 +17,8 @@
namespace deep_gemm {
torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
const int& mn, const int& k,
const std::optional<int>& num_groups,
const std::tuple<int, int, int>& recipe,
const std::optional<int>& num_groups,
const bool& is_sfa,
const bool& disable_ue8m0_cast) {
const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe);
@@ -121,8 +121,8 @@ void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
// Transform SFA and SFB into compute-required layout
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, recipe.value(), true, disable_ue8m0_cast);
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, std::nullopt, recipe.value(), false, disable_ue8m0_cast);
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
// Dispatch into different implementations
const auto& arch_major = device_runtime->get_arch_major();
@@ -133,7 +133,7 @@ void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types");
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
}
@@ -208,8 +208,8 @@ void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tens
// Transform SFA and SFB into compute-required layout
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, recipe.value(), true, disable_ue8m0_cast);
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, recipe.value(), false, disable_ue8m0_cast);
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
@@ -223,7 +223,7 @@ void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tens
sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
num_groups, m, n, k, major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types");
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
}
@@ -271,8 +271,8 @@ void fp8_m_grouped_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>&
// Transform scaling factors
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, num_groups, recipe.value(), true, disable_ue8m0_cast);
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, recipe.value(), false, disable_ue8m0_cast);
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast);
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
@@ -286,7 +286,7 @@ void fp8_m_grouped_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>&
sm100_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported kernel or scaling factor types");
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
}
@@ -339,18 +339,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "DeepGEMM C++ library";
// Runtime
m.def("set_num_sms", [&](const int& new_num_sms) {
device_runtime->set_num_sms(new_num_sms);
});
m.def("get_num_sms", [&]() {
return device_runtime->get_num_sms();
});
m.def("set_num_sms", [&](const int& new_num_sms) {
device_runtime->set_num_sms(new_num_sms);
m.def("set_tc_util", [&](const int& new_tc_util) {
device_runtime->set_tc_util(new_tc_util);
});
m.def("get_tc_util", [&]() {
return device_runtime->get_tc_util();
});
// JIT
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_torch) {
DG_HOST_ASSERT(get_env("DG_JIT_USE_NVRTC", 0) == 0 and "Currently only support NVCC");
compiler = std::make_shared<NVCCCompiler>(library_root_path, cuda_home_path_by_torch);
KernelRuntime::set_cuda_home(cuda_home_path_by_torch);
Compiler::prepare_init(library_root_path, cuda_home_path_by_torch);
KernelRuntime::prepare_init(cuda_home_path_by_torch);
});
// Stable kernel APIs with automatic arch/layout dispatch
@@ -391,7 +396,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
py::arg("recipe") = std::make_tuple(1, 1, 128),
py::arg("compiled_dims") = "mn");
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout);
// Layout kernels
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false,
py::arg("disable_ue8m0_cast") = false);
// Raw kernels or functions
m.def("get_tma_aligned_size", &get_tma_aligned_size);