Files
DeepGEMM/csrc/jit_kernels/impls/smxx_layout.hpp
Ray Wang 9da4a23561 Add more GPU architectures support (#112)
* Add more GPU architectures support

* Update layout.py

* Optimize performance, Add SM90 support, Add 1D2D SM100 support

* Add fmtlib submodule at commit 553ec11

---------

Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
2025-07-18 11:32:22 +08:00

200 lines
7.4 KiB
C++

#pragma once
#include <torch/python.h>
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../../utils/layout.hpp"
namespace deep_gemm {
class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime<TransposeAndPackFP32IntoUE8M0Runtime> {
public:
struct Args {
int mn, sf_k;
int block_mn;
void *sf, *out;
LaunchArgs launch_args;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&transpose_and_pack_fp32_into_ue8m0<
{}, {}, {}
>);
}};
)", args.launch_args.num_threads, args.block_mn, args.sf_k);
}
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, args.sf, args.out, static_cast<uint32_t>(args.mn)));
}
};
class PackFP32IntoUE8M0Runtime final: public LaunchRuntime<PackFP32IntoUE8M0Runtime> {
public:
struct Args {
int num_groups, mn, sf_k, packed_sf_k;
int block_mn, block_packed_sf_k;
void *sf, *out, *ks;
LaunchArgs launch_args;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&pack_fp32_into_ue8m0<
{}, {}, {}, {}
>);
}};
)", args.num_groups, args.launch_args.num_threads, args.block_mn, args.block_packed_sf_k);
}
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k));
}
};
static std::tuple<int, int, int, int, int, torch::Tensor> preprocess_sf(const torch::Tensor& sf) {
// NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
const auto& dim = sf.dim();
DG_HOST_ASSERT(dim == 2 or dim == 3);
DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat);
const auto& batched_sf = dim == 2 ? sf.unsqueeze(0) : sf;
const auto& [num_groups, mn, sf_k] = get_shape<3>(batched_sf);
const auto& tma_aligned_mn = get_tma_aligned_size(mn, static_cast<int>(sf.element_size()));
return {dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf};
}
static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
// The last kernel already gives a column-major TMA aligned layout
if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn)
return (dim == 2) ? batched_sf.squeeze(0) : batched_sf;
// Normal layout requires transposing
auto aligned_sf = torch::empty_strided({num_groups, tma_aligned_mn, sf_k}, {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, batched_sf.options());
aligned_sf = aligned_sf.slice(1, 0, mn).copy_(batched_sf);
return (dim == 2) ? aligned_sf.squeeze(0) : aligned_sf;
}
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) {
const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
const auto& packed_sf_k = ceil_div(sf_k, 4);
const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k},
{packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn},
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
DG_HOST_ASSERT(num_groups == 1 or (mn * sf_k) % 4 == 0);
// Launch the kernel
if (batched_sf.is_contiguous()) {
constexpr int block_mn = 48;
constexpr int num_threads = 512;
const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = {
.mn = mn,
.sf_k = sf_k,
.block_mn = block_mn,
.sf = batched_sf.data_ptr(),
.out = out.data_ptr(),
.launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4)
};
const auto& code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args);
} else {
DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1);
DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);
constexpr int block_mn = 128;
constexpr int block_packed_sf_k = 16;
constexpr int num_threads = 512;
const PackFP32IntoUE8M0Runtime::Args& args = {
.num_groups = 1,
.mn = mn,
.sf_k = sf_k,
.packed_sf_k = packed_sf_k,
.block_mn = block_mn,
.block_packed_sf_k = block_packed_sf_k,
.sf = batched_sf.data_ptr(),
.out = out.data_ptr(),
.ks = nullptr,
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
};
const auto& code = PackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code);
PackFP32IntoUE8M0Runtime::launch(runtime, args);
}
return (dim == 2) ? out.squeeze(0) : out;
}
static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf,
const torch::Tensor& ks_tensor,
const std::vector<int>& ks) {
const auto& [sf_k, mn] = get_shape<2>(sf);
const auto& num_groups = static_cast<int>(ks.size());
int ref_sf_k = 0, packed_sf_k = 0;
for (const auto& k: ks)
ref_sf_k += ceil_div(k, 128), packed_sf_k += ceil_div(k, 512);
DG_HOST_ASSERT(sf.is_contiguous());
DG_HOST_ASSERT(ref_sf_k == sf_k);
DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0);
const auto& out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt));
constexpr int block_mn = 128;
constexpr int block_packed_sf_k = 16;
constexpr int num_threads = 512;
const PackFP32IntoUE8M0Runtime::Args& args = {
.num_groups = num_groups,
.mn = mn,
.sf_k = sf_k,
.packed_sf_k = packed_sf_k,
.block_mn = block_mn,
.block_packed_sf_k = block_packed_sf_k,
.sf = sf.data_ptr(),
.out = out.data_ptr(),
.ks = ks_tensor.data_ptr(),
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
};
const auto& code = PackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code);
PackFP32IntoUE8M0Runtime::launch(runtime, args);
return out;
}
} // namespace deep_gemm