153 lines
6.1 KiB
C++
153 lines
6.1 KiB
C++
#pragma once
|
|
|
|
#include "../../jit/compiler.hpp"
|
|
#include "../../jit/device_runtime.hpp"
|
|
#include "../../jit/kernel_runtime.hpp"
|
|
#include "../heuristics/sm90.hpp"
|
|
#include "../heuristics/sm100.hpp"
|
|
#include "runtime_utils.hpp"
|
|
|
|
namespace deep_gemm {
|
|
|
|
class SM90FP8MQALogitsRuntime final: public LaunchRuntime<SM90FP8MQALogitsRuntime> {
|
|
public:
|
|
struct Args {
|
|
int seq_len;
|
|
int seq_len_kv;
|
|
int stride_kv;
|
|
int num_heads, head_dim;
|
|
int num_q_stages;
|
|
int num_kv_stages;
|
|
|
|
int block_q;
|
|
int block_kv;
|
|
|
|
int* cu_seq_len_k_start;
|
|
int* cu_seq_len_k_end;
|
|
float* logits;
|
|
float softmax_scale;
|
|
|
|
CUtensorMap tensor_map_q;
|
|
CUtensorMap tensor_map_kv;
|
|
CUtensorMap tensor_map_kv_scales;
|
|
CUtensorMap tensor_map_weights;
|
|
|
|
int num_specialized_threads;
|
|
int num_math_threads;
|
|
|
|
LaunchArgs launch_args;
|
|
};
|
|
|
|
static std::string generate_impl(const Args& args) {
|
|
// TODO: optimize performance by tuning args
|
|
// Block sizes are fixed in this kernel
|
|
DG_HOST_ASSERT(128 % args.num_heads == 0);
|
|
const auto& arch = device_runtime->get_arch(true);
|
|
|
|
return fmt::format(R"(
|
|
#include <deep_gemm/impls/sm{}_fp8_mqa_logits.cuh>
|
|
|
|
using namespace deep_gemm;
|
|
|
|
static void __instantiate_kernel() {{
|
|
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_mqa_logits<
|
|
{}, {},
|
|
{}, {},
|
|
{}, {},
|
|
{}, {}
|
|
>);
|
|
}};
|
|
)", arch, arch,
|
|
args.num_heads, args.head_dim,
|
|
args.block_q, args.block_kv,
|
|
args.num_q_stages, args.num_kv_stages,
|
|
args.num_specialized_threads, args.num_math_threads);
|
|
}
|
|
|
|
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
|
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
|
args.seq_len, args.seq_len_kv, static_cast<int64_t>(args.stride_kv),
|
|
args.cu_seq_len_k_start, args.cu_seq_len_k_end,
|
|
args.logits,
|
|
args.tensor_map_q, args.tensor_map_kv,
|
|
args.tensor_map_kv_scales, args.tensor_map_weights
|
|
));
|
|
}
|
|
};
|
|
|
|
static void smxx_fp8_mqa_logits(const torch::Tensor& q,
|
|
const torch::Tensor& kv, const torch::Tensor& kv_scales,
|
|
const torch::Tensor& weights,
|
|
const torch::Tensor& cu_seq_len_k_start,
|
|
const torch::Tensor& cu_seq_len_k_end,
|
|
const torch::Tensor& logits,
|
|
const int& seq_len, const int& seq_len_kv, const int& stride_kv,
|
|
const int& num_heads, const int& head_dim,
|
|
const int& seq_len_alignment) {
|
|
constexpr int block_qh = 128;
|
|
constexpr int block_kv = 256;
|
|
constexpr int num_specialized_threads = 128;
|
|
constexpr int num_math_threads = 512;
|
|
constexpr int num_q_stages = 3, num_kv_stages = 3;
|
|
const int block_q = block_qh / num_heads;
|
|
DG_HOST_ASSERT(block_qh % num_heads == 0);
|
|
DG_HOST_ASSERT(seq_len_alignment % block_q == 0);
|
|
|
|
// Construct TMAs
|
|
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
|
|
const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads,
|
|
head_dim, block_qh, head_dim, head_dim);
|
|
const auto& tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv,
|
|
head_dim, block_kv, head_dim, head_dim);
|
|
// According to the driver API, the minimal alignment is 256 bytes
|
|
// So it is safe for us to do a 16-byte OOB
|
|
const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_scales,
|
|
get_tma_aligned_size(seq_len_kv, static_cast<int>(kv_scales.element_size())),
|
|
1, block_kv, 1, 0, 0);
|
|
const auto& tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len,
|
|
num_heads, block_q, num_heads, 0);
|
|
|
|
// Calculate shared memory size
|
|
int smem_size = 0;
|
|
const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast<int>(q.element_size());
|
|
const int smem_weight_size_per_stage = block_q * num_heads * static_cast<int>(weights.element_size());
|
|
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv.element_size());
|
|
const int kv_scale_size_per_stage = block_kv * static_cast<int>(kv_scales.element_size());
|
|
smem_size += num_q_stages * smem_q_size_per_stage;
|
|
smem_size += num_kv_stages * smem_kv_size_per_stage;
|
|
smem_size += num_q_stages * smem_weight_size_per_stage;
|
|
smem_size += num_kv_stages * kv_scale_size_per_stage;
|
|
smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8;
|
|
smem_size += 4;
|
|
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
|
|
|
// Launch
|
|
const SM90FP8MQALogitsRuntime::Args& args = {
|
|
.seq_len = seq_len,
|
|
.seq_len_kv = seq_len_kv,
|
|
.stride_kv = stride_kv,
|
|
.num_heads = num_heads, .head_dim = head_dim,
|
|
.num_q_stages = num_q_stages,
|
|
.num_kv_stages = num_kv_stages,
|
|
.block_q = block_q,
|
|
.block_kv = block_kv,
|
|
.cu_seq_len_k_start = cu_seq_len_k_start.data_ptr<int>(),
|
|
.cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
|
|
.logits = logits.data_ptr<float>(),
|
|
.tensor_map_q = tensor_map_q,
|
|
.tensor_map_kv = tensor_map_kv,
|
|
.tensor_map_kv_scales = tensor_map_kv_scales,
|
|
.tensor_map_weights = tensor_map_weights,
|
|
.num_specialized_threads = num_specialized_threads,
|
|
.num_math_threads = num_math_threads,
|
|
.launch_args = LaunchArgs(device_runtime->get_num_sms(),
|
|
num_specialized_threads + num_math_threads,
|
|
smem_size)
|
|
};
|
|
const auto& code = SM90FP8MQALogitsRuntime::generate(args);
|
|
const auto& runtime = compiler->build("sm90_fp8_mqa_logits", code);
|
|
SM90FP8MQALogitsRuntime::launch(runtime, args);
|
|
}
|
|
|
|
} // namespace deep_gemm
|