Make various updates and fixes: (#164)
- Add BF16 support for SM90 and SM100 - Refactor Python APIs - Other fixes and code refactoring
This commit is contained in:
143
csrc/jit_kernels/impls/sm100_bf16_gemm.hpp
Normal file
143
csrc/jit_kernels/impls/sm100_bf16_gemm.hpp
Normal file
@@ -0,0 +1,143 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM100BF16GemmRuntime final: public LaunchRuntime<SM100BF16GemmRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
|
||||
void* grouped_layout;
|
||||
CUtensorMap tensor_map_a;
|
||||
CUtensorMap tensor_map_b;
|
||||
CUtensorMap tensor_map_c;
|
||||
CUtensorMap tensor_map_d;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm100_bf16_gemm_impl<
|
||||
{}, {},
|
||||
{}, {}, {},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
|
||||
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.num_groups,
|
||||
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
|
||||
args.gemm_config.tc_util);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
// TODO: optimize `args` copy
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.grouped_layout, args.m, args.n, args.k,
|
||||
args.tensor_map_a, args.tensor_map_b,
|
||||
args.tensor_map_c, args.tensor_map_d));
|
||||
}
|
||||
};
|
||||
|
||||
static void sm100_bf16_gemm(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
// TODO: test other Ks
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const auto& cd = c.value_or(d);
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_c = make_tma_cd_desc(cd, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(cd.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
|
||||
// Duplicate the accumulator if necessary
|
||||
if (c.has_value()) {
|
||||
if (c->data_ptr() == d.data_ptr()) {
|
||||
DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
|
||||
} else {
|
||||
// ReSharper disable once CppExpressionWithoutSideEffects
|
||||
d.copy_(c.value());
|
||||
}
|
||||
}
|
||||
|
||||
// Launch
|
||||
const SM100BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = nullptr,
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_c = tensor_map_c,
|
||||
.tensor_map_d = tensor_map_d
|
||||
};
|
||||
const auto& code = SM100BF16GemmRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_bf16_gemm", code);
|
||||
SM100BF16GemmRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
@@ -155,7 +156,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D1D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
@@ -202,7 +203,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_fp8_m_grouped_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
@@ -136,7 +137,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
@@ -179,7 +180,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
|
||||
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
|
||||
229
csrc/jit_kernels/impls/sm90_bf16_gemm.hpp
Normal file
229
csrc/jit_kernels/impls/sm90_bf16_gemm.hpp
Normal file
@@ -0,0 +1,229 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM90BF16GemmRuntime final: public LaunchRuntime<SM90BF16GemmRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
|
||||
void *grouped_layout;
|
||||
CUtensorMap tensor_map_a;
|
||||
CUtensorMap tensor_map_b;
|
||||
CUtensorMap tensor_map_d;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm90_bf16_gemm_impl<
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
// TODO: add CD dtype
|
||||
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
|
||||
args.num_groups,
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
// TODO: optimize `args` copy
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.grouped_layout,
|
||||
args.m, args.n, args.k,
|
||||
args.tensor_map_a, args.tensor_map_b,
|
||||
args.tensor_map_d));
|
||||
}
|
||||
};
|
||||
|
||||
static void sm90_bf16_gemm(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
|
||||
// Launch
|
||||
const SM90BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = nullptr,
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
};
|
||||
const auto& code = SM90BF16GemmRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_bf16_gemm", code);
|
||||
SM90BF16GemmRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
|
||||
// Launch
|
||||
const SM90BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = m_indices.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
};
|
||||
const auto& code = SM90BF16GemmRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code);
|
||||
SM90BF16GemmRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const int& expected_m,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedMasked, KernelType::KernelNoSF,
|
||||
expected_m, n, k, num_groups, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), num_groups,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
|
||||
// Launch
|
||||
const SM90BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = masked_m.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
};
|
||||
const auto& code = SM90BF16GemmRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code);
|
||||
SM90BF16GemmRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
@@ -139,7 +140,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
@@ -185,7 +186,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
|
||||
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm90_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
|
||||
@@ -10,6 +10,35 @@
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class TransposeFP32Runtime final: public LaunchRuntime<TransposeFP32Runtime> {
|
||||
public:
|
||||
struct Args {
|
||||
int mn, sf_k;
|
||||
int block_mn;
|
||||
void *sf, *out;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/smxx_layout.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&transpose_fp32<
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)", args.launch_args.num_threads, args.block_mn, args.sf_k);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast<uint32_t>(args.mn)));
|
||||
}
|
||||
};
|
||||
|
||||
class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime<TransposeAndPackFP32IntoUE8M0Runtime> {
|
||||
public:
|
||||
struct Args {
|
||||
@@ -88,10 +117,32 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
|
||||
if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn)
|
||||
return (dim == 2) ? batched_sf.squeeze(0) : batched_sf;
|
||||
|
||||
// Normal layout requires transposing
|
||||
auto aligned_sf = torch::empty_strided({num_groups, tma_aligned_mn, sf_k}, {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, batched_sf.options());
|
||||
aligned_sf = aligned_sf.slice(1, 0, mn).copy_(batched_sf);
|
||||
return (dim == 2) ? aligned_sf.squeeze(0) : aligned_sf;
|
||||
const auto& out = torch::empty_strided({num_groups, mn, sf_k},
|
||||
{tma_aligned_mn * sf_k, 1, tma_aligned_mn},
|
||||
batched_sf.options());
|
||||
|
||||
if (not batched_sf.is_contiguous()) {
|
||||
// Fallback to PyTorch's slow copy if not contiguous
|
||||
// ReSharper disable once CppExpressionWithoutSideEffects
|
||||
out.copy_(batched_sf);
|
||||
} else {
|
||||
constexpr int block_mn = 64;
|
||||
constexpr int num_threads = 512;
|
||||
const auto& smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast<int>(sizeof(float));
|
||||
const TransposeFP32Runtime::Args& args = {
|
||||
.mn = mn,
|
||||
.sf_k = sf_k,
|
||||
.block_mn = block_mn,
|
||||
.sf = batched_sf.data_ptr(),
|
||||
.out = out.data_ptr(),
|
||||
.launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size)
|
||||
};
|
||||
|
||||
const auto& code = TransposeFP32Runtime::generate(args);
|
||||
const auto& runtime = compiler->build("transpose_fp32", code);
|
||||
TransposeFP32Runtime::launch(runtime, args);
|
||||
}
|
||||
return (dim == 2) ? out.squeeze(0) : out;
|
||||
}
|
||||
|
||||
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) {
|
||||
@@ -127,7 +178,6 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T
|
||||
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
|
||||
// Launch the kernel
|
||||
if (batched_sf.is_contiguous()) {
|
||||
// Fallback to slow PyTorch impl for non-supported cases
|
||||
if ((mn * sf_k) % 4 != 0 and num_groups > 1)
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
|
||||
|
||||
@@ -146,11 +196,8 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T
|
||||
const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
|
||||
TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args);
|
||||
} else {
|
||||
// Fallback to slow PyTorch impl for non-supported cases
|
||||
if (mn % 4 != 0 or num_groups > 1)
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
|
||||
|
||||
DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1);
|
||||
DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);
|
||||
|
||||
constexpr int block_mn = 128;
|
||||
|
||||
Reference in New Issue
Block a user