Add more GPU architectures support (#112)
* Add more GPU architectures support * Update layout.py * Optimize performance, Add SM90 support, Add 1D2D SM100 support * Add fmtlib submodule at commit 553ec11 --------- Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
173
csrc/jit_kernels/impls/runtime_utils.hpp
Normal file
173
csrc/jit_kernels/impls/runtime_utils.hpp
Normal file
@@ -0,0 +1,173 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
static std::pair<int, int> get_inner_outer_dims(const cute::UMMA::Major& major, const int& k, const int& mn) {
|
||||
return major == cute::UMMA::Major::K ? std::make_pair(k, mn) : std::make_pair(mn, k);
|
||||
}
|
||||
|
||||
static int get_non_contiguous_dim(const cute::UMMA::Major& major) {
|
||||
return major == cute::UMMA::Major::K ? -2 : -1;
|
||||
}
|
||||
|
||||
static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) {
|
||||
for (const char& c: compiled_dims) {
|
||||
if (name == c)
|
||||
return dim;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static std::string to_string(const cute::UMMA::Major& major) {
|
||||
switch (major) {
|
||||
case cute::UMMA::Major::K: return "cute::UMMA::Major::K";
|
||||
case cute::UMMA::Major::MN: return "cute::UMMA::Major::MN";
|
||||
}
|
||||
DG_HOST_UNREACHABLE("Unknown major");
|
||||
}
|
||||
|
||||
static std::string to_string(const GemmType& type) {
|
||||
switch (type) {
|
||||
case GemmType::Normal: return "GemmType::Normal";
|
||||
case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous";
|
||||
case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked";
|
||||
case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous";
|
||||
}
|
||||
DG_HOST_UNREACHABLE("Unknown GEMM type");
|
||||
}
|
||||
|
||||
static std::string to_string(const at::ScalarType& dtype) {
|
||||
switch (dtype) {
|
||||
case torch::kInt: return "int";
|
||||
case torch::kFloat: return "float";
|
||||
case torch::kBFloat16: return "cutlass::bfloat16_t";
|
||||
default: DG_HOST_UNREACHABLE("Unsupported dtype");
|
||||
}
|
||||
}
|
||||
|
||||
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) {
|
||||
switch (dtype) {
|
||||
case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
||||
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
|
||||
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
|
||||
default: DG_HOST_UNREACHABLE("Unsupported dtype");
|
||||
}
|
||||
}
|
||||
|
||||
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) {
|
||||
switch (mode) {
|
||||
case 0: return CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
case 16: return CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
case 32: return CU_TENSOR_MAP_SWIZZLE_32B;
|
||||
case 64: return CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
case 128: return CU_TENSOR_MAP_SWIZZLE_128B;
|
||||
default: DG_HOST_UNREACHABLE("Unsupported swizzling mode");
|
||||
}
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
|
||||
int gmem_inner_dim, int gmem_outer_dim,
|
||||
int smem_inner_dim, int smem_outer_dim,
|
||||
const int& gmem_outer_stride,
|
||||
const int& swizzle_mode) {
|
||||
const auto& elem_size = static_cast<int>(t.element_size());
|
||||
if (swizzle_mode != 0)
|
||||
smem_inner_dim = swizzle_mode / elem_size;
|
||||
|
||||
CUtensorMap tensor_map;
|
||||
const cuuint64_t gmem_dims[2] = {static_cast<cuuint64_t>(gmem_inner_dim), static_cast<cuuint64_t>(gmem_outer_dim)};
|
||||
const cuuint32_t smem_dims[2] = {static_cast<cuuint32_t>(smem_inner_dim), static_cast<cuuint32_t>(smem_outer_dim)};
|
||||
const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride * elem_size), };
|
||||
const cuuint32_t elem_strides[2] = {1, 1};
|
||||
if (get_env<int>("DG_JIT_DEBUG")) {
|
||||
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n",
|
||||
gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
|
||||
gmem_outer_stride, swizzle_mode, elem_size);
|
||||
}
|
||||
DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
|
||||
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()),
|
||||
2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode),
|
||||
CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
|
||||
return tensor_map;
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
|
||||
const torch::Tensor& t,
|
||||
const int& shape_m, const int& shape_k,
|
||||
const int& block_m, const int& block_k,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
if (num_groups > 1)
|
||||
DG_HOST_ASSERT(major == cute::UMMA::Major::K);
|
||||
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
|
||||
const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_m);
|
||||
return make_tma_2d_desc(t,
|
||||
gmem_inner_dim, gmem_outer_dim,
|
||||
smem_inner_dim, smem_outer_dim,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
|
||||
const torch::Tensor& t,
|
||||
const int& shape_n, const int& shape_k,
|
||||
const int& block_n, const int& block_k,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n);
|
||||
const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n);
|
||||
|
||||
// `num_groups` is always applied into the outer dimensions
|
||||
return make_tma_2d_desc(t,
|
||||
gmem_inner_dim, gmem_outer_dim * num_groups,
|
||||
smem_inner_dim, smem_outer_dim,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
|
||||
const int& shape_m, const int& shape_n,
|
||||
const int& block_m, const int& block_n,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
|
||||
// Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode`
|
||||
// bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
|
||||
return make_tma_2d_desc(t,
|
||||
shape_n, shape_m * num_groups,
|
||||
block_n, block_m,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
const torch::Tensor& t,
|
||||
int shape_mn, int shape_k,
|
||||
const int& block_mn, const int& block_k,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
DG_HOST_ASSERT(major == cute::UMMA::Major::MN);
|
||||
|
||||
// TODO: maybe swizzle SF as well
|
||||
DG_HOST_ASSERT(swizzle_mode == 0);
|
||||
|
||||
shape_mn = get_tma_aligned_size(shape_mn, static_cast<int>(t.element_size()));
|
||||
return make_tma_2d_desc(t,
|
||||
shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
|
||||
block_mn, 1,
|
||||
shape_mn,
|
||||
swizzle_mode);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
351
csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
Normal file
351
csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
Normal file
@@ -0,0 +1,351 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8Gemm1D1DRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
|
||||
void* grouped_layout;
|
||||
CUtensorMap tensor_map_a;
|
||||
CUtensorMap tensor_map_b;
|
||||
CUtensorMap tensor_map_sfa;
|
||||
CUtensorMap tensor_map_sfb;
|
||||
CUtensorMap tensor_map_c;
|
||||
CUtensorMap tensor_map_d;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d1d_impl<
|
||||
{}, {},
|
||||
{}, {}, {},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
|
||||
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.num_groups,
|
||||
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
to_string(args.gemm_config.gemm_type),
|
||||
args.gemm_config.with_accumulation,
|
||||
to_string(args.gemm_config.cd_dtype));
|
||||
}
|
||||
|
||||
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
|
||||
// TODO: optimize `args` copy
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
|
||||
args.grouped_layout, args.m, args.n, args.k,
|
||||
args.tensor_map_a, args.tensor_map_b,
|
||||
args.tensor_map_sfa, args.tensor_map_sfb,
|
||||
args.tensor_map_c, args.tensor_map_d));
|
||||
}
|
||||
};
|
||||
|
||||
static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::Kernel1D1D,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const auto& cd = c.value_or(d);
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_c = make_tma_cd_desc(cd, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(cd.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, 1, 0);
|
||||
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
|
||||
config.block_n, config.block_k, 1, 0);
|
||||
|
||||
// Duplicate the accumulator if necessary
|
||||
if (c.has_value()) {
|
||||
if (c->data_ptr() == d.data_ptr()) {
|
||||
DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
|
||||
} else {
|
||||
// ReSharper disable once CppExpressionWithoutSideEffects
|
||||
d.copy_(c.value());
|
||||
}
|
||||
}
|
||||
|
||||
// Launch
|
||||
const SM100FP8Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = nullptr,
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
.tensor_map_sfb = tensor_map_sfb,
|
||||
.tensor_map_c = tensor_map_c,
|
||||
.tensor_map_d = tensor_map_d
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D1D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Create tensor descriptors
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, 1, 0);
|
||||
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
|
||||
config.block_n, config.block_k, num_groups, 0);
|
||||
|
||||
// Launch kernel
|
||||
const SM100FP8Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = m_indices.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
.tensor_map_sfb = tensor_map_sfb,
|
||||
.tensor_map_c = tensor_map_d,
|
||||
.tensor_map_d = tensor_map_d
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code);
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_fp8_m_grouped_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const int& expected_m,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedMasked, KernelType::Kernel1D1D,
|
||||
expected_m, n, k, num_groups, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Create tensor descriptors
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), num_groups,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, num_groups, 0);
|
||||
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
|
||||
config.block_n, config.block_k, num_groups, 0);
|
||||
|
||||
// Launch kernel
|
||||
const SM100FP8Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = masked_m.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
.tensor_map_sfb = tensor_map_sfb,
|
||||
.tensor_map_c = tensor_map_d,
|
||||
.tensor_map_d = tensor_map_d
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code);
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n,
|
||||
const std::vector<int>& ks, const torch::Tensor& ks_tensor,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);
|
||||
|
||||
int sum_k = 0, sum_sf_k = 0;
|
||||
for (const auto& k: ks) {
|
||||
sum_k += k, sum_sf_k += ceil_div(k, 512);
|
||||
DG_HOST_ASSERT(k % 128 == 0);
|
||||
}
|
||||
const auto& num_groups = static_cast<int>(ks.size());
|
||||
|
||||
// Get config using max K for better performance
|
||||
const auto& max_k = *std::ranges::max_element(ks);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
|
||||
m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Create tensor descriptors
|
||||
const auto& cd = c.value_or(d);
|
||||
const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
|
||||
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(0)), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
|
||||
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(0)), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(1)), num_groups,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_c = make_tma_cd_desc(cd, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(cd.stride(1)), num_groups,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512,
|
||||
config.block_m, config.block_k, num_groups, 0);
|
||||
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512,
|
||||
config.block_n, config.block_k, num_groups, 0);
|
||||
|
||||
// Duplicate the accumulator if necessary
|
||||
if (c.has_value()) {
|
||||
DG_HOST_ASSERT(c->data_ptr() == d.data_ptr());
|
||||
DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
|
||||
}
|
||||
|
||||
// Launch kernel
|
||||
const SM100FP8Gemm1D1DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = sum_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = ks_tensor.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
.tensor_map_sfb = tensor_map_sfb,
|
||||
.tensor_map_c = tensor_map_c,
|
||||
.tensor_map_d = tensor_map_d
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code);
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
242
csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp
Normal file
242
csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp
Normal file
@@ -0,0 +1,242 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM100FP8Gemm1D2DRuntime final: public LaunchRuntime<SM100FP8Gemm1D2DRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
|
||||
void *sfb, *grouped_layout;
|
||||
CUtensorMap tensor_map_a;
|
||||
CUtensorMap tensor_map_b;
|
||||
CUtensorMap tensor_map_d;
|
||||
CUtensorMap tensor_map_sfa;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d2d_impl<
|
||||
{}, {},
|
||||
{}, {}, {},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
|
||||
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.num_groups,
|
||||
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
to_string(args.gemm_config.gemm_type),
|
||||
to_string(args.gemm_config.cd_dtype));
|
||||
}
|
||||
|
||||
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
|
||||
// TODO: optimize `args` copy
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
|
||||
args.sfb, args.grouped_layout,
|
||||
args.m, args.n, args.k,
|
||||
args.tensor_map_a, args.tensor_map_b,
|
||||
args.tensor_map_d, args.tensor_map_sfa));
|
||||
}
|
||||
};
|
||||
|
||||
static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::Kernel1D2D,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, 1, 0);
|
||||
|
||||
// Launch
|
||||
const SM100FP8Gemm1D2DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.sfb = sfb.data_ptr(),
|
||||
.grouped_layout = nullptr,
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_fp8_gemm_1d2d", code);
|
||||
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, 1, 0);
|
||||
|
||||
// Launch
|
||||
const SM100FP8Gemm1D2DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.sfb = sfb.data_ptr(),
|
||||
.grouped_layout = m_indices.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d2d", code);
|
||||
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const int& expected_m,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedMasked, KernelType::Kernel1D2D,
|
||||
expected_m, n, k, num_groups, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), num_groups,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, num_groups, 0);
|
||||
|
||||
// Launch
|
||||
const SM100FP8Gemm1D2DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.sfb = sfb.data_ptr(),
|
||||
.grouped_layout = masked_m.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
};
|
||||
const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d2d", code);
|
||||
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
255
csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
Normal file
255
csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
Normal file
@@ -0,0 +1,255 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
|
||||
void *sfb, *grouped_layout;
|
||||
CUtensorMap tensor_map_a;
|
||||
CUtensorMap tensor_map_b;
|
||||
CUtensorMap tensor_map_d;
|
||||
CUtensorMap tensor_map_sfa;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d2d_impl<
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
// TODO: add CD dtype
|
||||
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
|
||||
args.num_groups,
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
to_string(args.gemm_config.gemm_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
|
||||
// TODO: optimize `args` copy
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
|
||||
args.sfb, args.grouped_layout,
|
||||
args.m, args.n, args.k,
|
||||
args.tensor_map_a, args.tensor_map_b,
|
||||
args.tensor_map_d, args.tensor_map_sfa));
|
||||
}
|
||||
};
|
||||
|
||||
static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Normal, KernelType::Kernel1D2D,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
|
||||
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, 1, 0);
|
||||
|
||||
// Launch
|
||||
const SM90FP8Gemm1D2DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.sfb = sfb.data_ptr(),
|
||||
.grouped_layout = nullptr,
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
};
|
||||
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code);
|
||||
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
|
||||
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, 1, 0);
|
||||
|
||||
// Launch
|
||||
const SM90FP8Gemm1D2DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.sfb = sfb.data_ptr(),
|
||||
.grouped_layout = m_indices.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
};
|
||||
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code);
|
||||
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm90_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const int& expected_m,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedMasked, KernelType::Kernel1D2D,
|
||||
expected_m, n, k, num_groups, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
|
||||
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), num_groups,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
|
||||
config.block_m, config.block_k, num_groups, 0);
|
||||
|
||||
// Launch
|
||||
const SM90FP8Gemm1D2DRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.sfb = sfb.data_ptr(),
|
||||
.grouped_layout = masked_m.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
.tensor_map_sfa = tensor_map_sfa,
|
||||
};
|
||||
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
|
||||
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
199
csrc/jit_kernels/impls/smxx_layout.hpp
Normal file
199
csrc/jit_kernels/impls/smxx_layout.hpp
Normal file
@@ -0,0 +1,199 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../../utils/layout.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime<TransposeAndPackFP32IntoUE8M0Runtime> {
|
||||
public:
|
||||
struct Args {
|
||||
int mn, sf_k;
|
||||
int block_mn;
|
||||
void *sf, *out;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <deep_gemm/impls/smxx_layout.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&transpose_and_pack_fp32_into_ue8m0<
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)", args.launch_args.num_threads, args.block_mn, args.sf_k);
|
||||
}
|
||||
|
||||
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, args.sf, args.out, static_cast<uint32_t>(args.mn)));
|
||||
}
|
||||
};
|
||||
|
||||
class PackFP32IntoUE8M0Runtime final: public LaunchRuntime<PackFP32IntoUE8M0Runtime> {
|
||||
public:
|
||||
struct Args {
|
||||
int num_groups, mn, sf_k, packed_sf_k;
|
||||
int block_mn, block_packed_sf_k;
|
||||
void *sf, *out, *ks;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <deep_gemm/impls/smxx_layout.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&pack_fp32_into_ue8m0<
|
||||
{}, {}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)", args.num_groups, args.launch_args.num_threads, args.block_mn, args.block_packed_sf_k);
|
||||
}
|
||||
|
||||
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
|
||||
args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k));
|
||||
}
|
||||
};
|
||||
|
||||
static std::tuple<int, int, int, int, int, torch::Tensor> preprocess_sf(const torch::Tensor& sf) {
|
||||
// NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
|
||||
const auto& dim = sf.dim();
|
||||
DG_HOST_ASSERT(dim == 2 or dim == 3);
|
||||
DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat);
|
||||
const auto& batched_sf = dim == 2 ? sf.unsqueeze(0) : sf;
|
||||
|
||||
const auto& [num_groups, mn, sf_k] = get_shape<3>(batched_sf);
|
||||
const auto& tma_aligned_mn = get_tma_aligned_size(mn, static_cast<int>(sf.element_size()));
|
||||
return {dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf};
|
||||
}
|
||||
|
||||
static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
|
||||
const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
|
||||
|
||||
// The last kernel already gives a column-major TMA aligned layout
|
||||
if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn)
|
||||
return (dim == 2) ? batched_sf.squeeze(0) : batched_sf;
|
||||
|
||||
// Normal layout requires transposing
|
||||
auto aligned_sf = torch::empty_strided({num_groups, tma_aligned_mn, sf_k}, {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, batched_sf.options());
|
||||
aligned_sf = aligned_sf.slice(1, 0, mn).copy_(batched_sf);
|
||||
return (dim == 2) ? aligned_sf.squeeze(0) : aligned_sf;
|
||||
}
|
||||
|
||||
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) {
|
||||
const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
|
||||
const auto& packed_sf_k = ceil_div(sf_k, 4);
|
||||
const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k},
|
||||
{packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn},
|
||||
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
|
||||
DG_HOST_ASSERT(num_groups == 1 or (mn * sf_k) % 4 == 0);
|
||||
|
||||
// Launch the kernel
|
||||
if (batched_sf.is_contiguous()) {
|
||||
constexpr int block_mn = 48;
|
||||
constexpr int num_threads = 512;
|
||||
const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = {
|
||||
.mn = mn,
|
||||
.sf_k = sf_k,
|
||||
.block_mn = block_mn,
|
||||
.sf = batched_sf.data_ptr(),
|
||||
.out = out.data_ptr(),
|
||||
.launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4)
|
||||
};
|
||||
|
||||
const auto& code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args);
|
||||
const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
|
||||
TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args);
|
||||
} else {
|
||||
DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1);
|
||||
DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);
|
||||
|
||||
constexpr int block_mn = 128;
|
||||
constexpr int block_packed_sf_k = 16;
|
||||
constexpr int num_threads = 512;
|
||||
const PackFP32IntoUE8M0Runtime::Args& args = {
|
||||
.num_groups = 1,
|
||||
.mn = mn,
|
||||
.sf_k = sf_k,
|
||||
.packed_sf_k = packed_sf_k,
|
||||
.block_mn = block_mn,
|
||||
.block_packed_sf_k = block_packed_sf_k,
|
||||
.sf = batched_sf.data_ptr(),
|
||||
.out = out.data_ptr(),
|
||||
.ks = nullptr,
|
||||
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
|
||||
};
|
||||
|
||||
const auto& code = PackFP32IntoUE8M0Runtime::generate(args);
|
||||
const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code);
|
||||
PackFP32IntoUE8M0Runtime::launch(runtime, args);
|
||||
}
|
||||
return (dim == 2) ? out.squeeze(0) : out;
|
||||
}
|
||||
|
||||
static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf,
|
||||
const torch::Tensor& ks_tensor,
|
||||
const std::vector<int>& ks) {
|
||||
const auto& [sf_k, mn] = get_shape<2>(sf);
|
||||
const auto& num_groups = static_cast<int>(ks.size());
|
||||
|
||||
int ref_sf_k = 0, packed_sf_k = 0;
|
||||
for (const auto& k: ks)
|
||||
ref_sf_k += ceil_div(k, 128), packed_sf_k += ceil_div(k, 512);
|
||||
DG_HOST_ASSERT(sf.is_contiguous());
|
||||
DG_HOST_ASSERT(ref_sf_k == sf_k);
|
||||
DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0);
|
||||
|
||||
const auto& out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt));
|
||||
|
||||
constexpr int block_mn = 128;
|
||||
constexpr int block_packed_sf_k = 16;
|
||||
constexpr int num_threads = 512;
|
||||
const PackFP32IntoUE8M0Runtime::Args& args = {
|
||||
.num_groups = num_groups,
|
||||
.mn = mn,
|
||||
.sf_k = sf_k,
|
||||
.packed_sf_k = packed_sf_k,
|
||||
.block_mn = block_mn,
|
||||
.block_packed_sf_k = block_packed_sf_k,
|
||||
.sf = sf.data_ptr(),
|
||||
.out = out.data_ptr(),
|
||||
.ks = ks_tensor.data_ptr(),
|
||||
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
|
||||
};
|
||||
|
||||
const auto& code = PackFP32IntoUE8M0Runtime::generate(args);
|
||||
const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code);
|
||||
PackFP32IntoUE8M0Runtime::launch(runtime, args);
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
Reference in New Issue
Block a user