Add SM90 kernels (#200)
This commit is contained in:
@@ -4,6 +4,9 @@
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/smxx_fp8_mqa_logits.hpp"
|
||||
#include "../jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp"
|
||||
#include "../jit_kernels/impls/smxx_clean_logits.hpp"
|
||||
|
||||
#include "layout.hpp"
|
||||
|
||||
@@ -66,12 +69,172 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
|
||||
}
|
||||
}
|
||||
|
||||
static torch::Tensor fp8_mqa_logits(const torch::Tensor& q,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& kv,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& cu_seq_len_k_start,
|
||||
const torch::Tensor& cu_seq_len_k_end,
|
||||
const bool& clean_logits) {
|
||||
const auto& [seq_len, num_heads, head_dim] = get_shape<3>(q);
|
||||
const auto& [seq_len_kv, head_dim_] = get_shape<2>(kv.first);
|
||||
const auto& [seq_len_, num_heads_] = get_shape<2>(weights);
|
||||
const auto& [seq_len_kv_] = get_shape<1>(kv.second);
|
||||
|
||||
DG_HOST_ASSERT(seq_len == seq_len_);
|
||||
DG_HOST_ASSERT(num_heads == num_heads_ and head_dim == head_dim_);
|
||||
DG_HOST_ASSERT(seq_len_kv == seq_len_kv_);
|
||||
DG_HOST_ASSERT(cu_seq_len_k_start.size(0) == seq_len);
|
||||
DG_HOST_ASSERT(cu_seq_len_k_end.size(0) == seq_len);
|
||||
|
||||
DG_HOST_ASSERT(q.is_contiguous() and kv.first.is_contiguous());
|
||||
DG_HOST_ASSERT(kv.second.is_contiguous());
|
||||
DG_HOST_ASSERT(weights.is_contiguous());
|
||||
DG_HOST_ASSERT(cu_seq_len_k_start.is_contiguous());
|
||||
DG_HOST_ASSERT(cu_seq_len_k_end.is_contiguous());
|
||||
|
||||
DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(kv.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(kv.second.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(cu_seq_len_k_start.scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(cu_seq_len_k_end.scalar_type() == torch::kInt);
|
||||
|
||||
constexpr int seq_len_alignment = 4;
|
||||
constexpr int block_kv = 256;
|
||||
const auto aligned_seq_len = align(seq_len, seq_len_alignment);
|
||||
const auto aligned_seq_len_kv = align(seq_len_kv + block_kv, 4);
|
||||
auto logits = torch::empty({aligned_seq_len, aligned_seq_len_kv}, q.options().dtype(torch::kFloat));
|
||||
logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, seq_len_kv)});
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits,
|
||||
seq_len, seq_len_kv, aligned_seq_len_kv, num_heads, head_dim, seq_len_alignment);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
|
||||
// Clean unfilled logits
|
||||
if (clean_logits)
|
||||
smxx_clean_logits(logits, cu_seq_len_k_start, cu_seq_len_k_end, 1, seq_len, seq_len_kv, aligned_seq_len_kv);
|
||||
return logits;
|
||||
}
|
||||
|
||||
static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms) {
|
||||
const auto& [batch_size] = get_shape<1>(context_lens);
|
||||
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(context_lens.is_contiguous());
|
||||
|
||||
auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options());
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, block_kv, num_sms);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
|
||||
return schedule_metadata;
|
||||
}
|
||||
|
||||
static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
const torch::Tensor& fused_kv_cache,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& context_lens,
|
||||
const torch::Tensor& block_table,
|
||||
const torch::Tensor& schedule_meta,
|
||||
const int& max_context_len,
|
||||
const bool& clean_logits) {
|
||||
const auto& [batch_size, next_n, num_heads, head_dim] = get_shape<4>(q);
|
||||
const auto& [num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf] = get_shape<4>(fused_kv_cache);
|
||||
const auto& [batch_size_] = get_shape<1>(context_lens);
|
||||
const auto& [batch_size_next_n, num_heads_] = get_shape<2>(weights);
|
||||
const auto& [batch_size__, max_block_len] = get_shape<2>(block_table);
|
||||
const auto& [schedule_meta_size, meta_info_size] = get_shape<2>(schedule_meta);
|
||||
const auto& num_sms = device_runtime->get_num_sms();
|
||||
const auto& kv_cache_stride_bytes = fused_kv_cache.stride(0);
|
||||
const auto& block_table_stride = block_table.stride(0);
|
||||
|
||||
DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size__);
|
||||
DG_HOST_ASSERT(batch_size_next_n == batch_size * next_n);
|
||||
DG_HOST_ASSERT(num_heads == num_heads_ and num_heads_kv == 1);
|
||||
DG_HOST_ASSERT(head_dim_with_sf == head_dim + static_cast<int>(sizeof(float)));
|
||||
DG_HOST_ASSERT(schedule_meta_size == num_sms + 1 and meta_info_size == 2);
|
||||
|
||||
DG_HOST_ASSERT(next_n == 1 or next_n == 2);
|
||||
DG_HOST_ASSERT(block_kv == 64);
|
||||
|
||||
DG_HOST_ASSERT(q.is_contiguous());
|
||||
DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(float) == 0);
|
||||
DG_HOST_ASSERT(fused_kv_cache.stride(1) == head_dim_with_sf);
|
||||
DG_HOST_ASSERT(fused_kv_cache.stride(2) == head_dim_with_sf);
|
||||
DG_HOST_ASSERT(fused_kv_cache.stride(3) == 1);
|
||||
DG_HOST_ASSERT(weights.is_contiguous());
|
||||
DG_HOST_ASSERT(context_lens.is_contiguous());
|
||||
DG_HOST_ASSERT(block_table.stride(1) == 1);
|
||||
DG_HOST_ASSERT(schedule_meta.is_contiguous());
|
||||
|
||||
DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte);
|
||||
DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(block_table.scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(schedule_meta.scalar_type() == torch::kInt);
|
||||
|
||||
// Derive FP8 values and SF tensor from KV cache
|
||||
const auto& kv_cache = torch::from_blob(
|
||||
fused_kv_cache.data_ptr(),
|
||||
{num_kv_blocks, block_kv, head_dim},
|
||||
{kv_cache_stride_bytes, head_dim, 1},
|
||||
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn)
|
||||
);
|
||||
const auto& kv_cache_scales = torch::from_blob(
|
||||
fused_kv_cache.data_ptr<uint8_t>() + block_kv * head_dim,
|
||||
{num_kv_blocks, block_kv},
|
||||
{kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 1},
|
||||
torch::TensorOptions().dtype(torch::kFloat32)
|
||||
);
|
||||
|
||||
// Allocate output
|
||||
constexpr int num_math_warp_groups = 4;
|
||||
const auto& aligned_max_context_len = align(max_context_len, num_math_warp_groups * block_kv);
|
||||
auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q.options().dtype(torch::kFloat));
|
||||
logits = logits.slice(-1, 0, max_context_len);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta,
|
||||
batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv,
|
||||
kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, num_math_warp_groups);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
|
||||
// Clean unfilled logits
|
||||
if (clean_logits)
|
||||
smxx_clean_logits(logits, std::nullopt, context_lens, next_n, batch_size * next_n, max_context_len, aligned_max_context_len);
|
||||
return logits;
|
||||
}
|
||||
|
||||
static void register_apis(pybind11::module_& m) {
|
||||
m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("head_splits"),
|
||||
py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_mqa_logits", &fp8_mqa_logits,
|
||||
py::arg("q"), py::arg("kv"), py::arg("weights"),
|
||||
py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"),
|
||||
py::arg("clean_logits") = true);
|
||||
m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata,
|
||||
py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"));
|
||||
m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits,
|
||||
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
|
||||
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),
|
||||
py::arg("max_context_len"), py::arg("clean_logits") = false);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm::attention
|
||||
|
||||
@@ -6,6 +6,11 @@
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
|
||||
|
||||
// Attention kernels
|
||||
#include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
|
||||
#include <deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh>
|
||||
#include <deep_gemm/impls/smxx_clean_logits.cuh>
|
||||
|
||||
// Einsum kernels
|
||||
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
|
||||
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
|
||||
|
||||
79
csrc/jit_kernels/impls/smxx_clean_logits.hpp
Normal file
79
csrc/jit_kernels/impls/smxx_clean_logits.hpp
Normal file
@@ -0,0 +1,79 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SMXXCleanLogitsRuntime final: public LaunchRuntime<SMXXCleanLogitsRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int next_n;
|
||||
int seq_len;
|
||||
int seq_len_kv;
|
||||
uint64_t stride_kv;
|
||||
|
||||
int* cu_seq_len_k_start;
|
||||
int* cu_seq_len_k_end;
|
||||
float* logits;
|
||||
|
||||
int block_kv;
|
||||
int num_warps;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/smxx_clean_logits.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&smxx_clean_logits<
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)", args.next_n, args.block_kv, args.num_warps);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.seq_len, args.seq_len_kv, static_cast<int64_t>(args.stride_kv),
|
||||
args.cu_seq_len_k_start, args.cu_seq_len_k_end, args.logits
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
static void smxx_clean_logits(const torch::Tensor& logits,
|
||||
const std::optional<torch::Tensor>& cu_seq_len_k_start,
|
||||
const torch::Tensor& cu_seq_len_k_end,
|
||||
const int& next_n,
|
||||
const int& seq_len, const int& seq_len_kv,
|
||||
const uint64_t &stride_kv) {
|
||||
const int block_kv = 8192;
|
||||
const int num_warps = 8;
|
||||
const int smem_size = block_kv * sizeof(float);
|
||||
|
||||
// Launch
|
||||
const SMXXCleanLogitsRuntime::Args& args = {
|
||||
.next_n = next_n,
|
||||
.seq_len = seq_len,
|
||||
.seq_len_kv = seq_len_kv,
|
||||
.stride_kv = stride_kv,
|
||||
.cu_seq_len_k_start = cu_seq_len_k_start.has_value() ? cu_seq_len_k_start.value().data_ptr<int>() : nullptr,
|
||||
.cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
|
||||
.logits = logits.data_ptr<float>(),
|
||||
.block_kv = block_kv,
|
||||
.num_warps = num_warps,
|
||||
.launch_args = LaunchArgs(device_runtime->get_num_sms(),
|
||||
num_warps * 32, smem_size)
|
||||
};
|
||||
const auto& code = SMXXCleanLogitsRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("smxx_clean_logits", code);
|
||||
SMXXCleanLogitsRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
152
csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp
Normal file
152
csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp
Normal file
@@ -0,0 +1,152 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM90FP8MQALogitsRuntime final: public LaunchRuntime<SM90FP8MQALogitsRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int seq_len;
|
||||
int seq_len_kv;
|
||||
int stride_kv;
|
||||
int num_heads, head_dim;
|
||||
int num_q_stages;
|
||||
int num_kv_stages;
|
||||
|
||||
int block_q;
|
||||
int block_kv;
|
||||
|
||||
int* cu_seq_len_k_start;
|
||||
int* cu_seq_len_k_end;
|
||||
float* logits;
|
||||
float softmax_scale;
|
||||
|
||||
CUtensorMap tensor_map_q;
|
||||
CUtensorMap tensor_map_kv;
|
||||
CUtensorMap tensor_map_kv_scales;
|
||||
CUtensorMap tensor_map_weights;
|
||||
|
||||
int num_specialized_threads;
|
||||
int num_math_threads;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
// TODO: optimize performance by tuning args
|
||||
// Block sizes are fixed in this kernel
|
||||
DG_HOST_ASSERT(128 % args.num_heads == 0);
|
||||
const auto& arch = device_runtime->get_arch(true);
|
||||
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm{}_fp8_mqa_logits.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_mqa_logits<
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)", arch, arch,
|
||||
args.num_heads, args.head_dim,
|
||||
args.block_q, args.block_kv,
|
||||
args.num_q_stages, args.num_kv_stages,
|
||||
args.num_specialized_threads, args.num_math_threads);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.seq_len, args.seq_len_kv, static_cast<int64_t>(args.stride_kv),
|
||||
args.cu_seq_len_k_start, args.cu_seq_len_k_end,
|
||||
args.logits,
|
||||
args.tensor_map_q, args.tensor_map_kv,
|
||||
args.tensor_map_kv_scales, args.tensor_map_weights
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
static void smxx_fp8_mqa_logits(const torch::Tensor& q,
|
||||
const torch::Tensor& kv, const torch::Tensor& kv_scales,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& cu_seq_len_k_start,
|
||||
const torch::Tensor& cu_seq_len_k_end,
|
||||
const torch::Tensor& logits,
|
||||
const int& seq_len, const int& seq_len_kv, const int& stride_kv,
|
||||
const int& num_heads, const int& head_dim,
|
||||
const int& seq_len_alignment) {
|
||||
constexpr int block_qh = 128;
|
||||
constexpr int block_kv = 256;
|
||||
constexpr int num_specialized_threads = 128;
|
||||
constexpr int num_math_threads = 512;
|
||||
constexpr int num_q_stages = 3, num_kv_stages = 3;
|
||||
const int block_q = block_qh / num_heads;
|
||||
DG_HOST_ASSERT(block_qh % num_heads == 0);
|
||||
DG_HOST_ASSERT(seq_len_alignment % block_q == 0);
|
||||
|
||||
// Construct TMAs
|
||||
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
|
||||
const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads,
|
||||
head_dim, block_qh, head_dim, head_dim);
|
||||
const auto& tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv,
|
||||
head_dim, block_kv, head_dim, head_dim);
|
||||
// According to the driver API, the minimal alignment is 256 bytes
|
||||
// So it is safe for us to do a 16-byte OOB
|
||||
const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_scales,
|
||||
get_tma_aligned_size(seq_len_kv, static_cast<int>(kv_scales.element_size())),
|
||||
1, block_kv, 1, 0, 0);
|
||||
const auto& tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len,
|
||||
num_heads, block_q, num_heads, 0);
|
||||
|
||||
// Calculate shared memory size
|
||||
int smem_size = 0;
|
||||
const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast<int>(q.element_size());
|
||||
const int smem_weight_size_per_stage = block_q * num_heads * static_cast<int>(weights.element_size());
|
||||
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv.element_size());
|
||||
const int kv_scale_size_per_stage = block_kv * static_cast<int>(kv_scales.element_size());
|
||||
smem_size += num_q_stages * smem_q_size_per_stage;
|
||||
smem_size += num_kv_stages * smem_kv_size_per_stage;
|
||||
smem_size += num_q_stages * smem_weight_size_per_stage;
|
||||
smem_size += num_kv_stages * kv_scale_size_per_stage;
|
||||
smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8;
|
||||
smem_size += 4;
|
||||
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
||||
|
||||
// Launch
|
||||
const SM90FP8MQALogitsRuntime::Args& args = {
|
||||
.seq_len = seq_len,
|
||||
.seq_len_kv = seq_len_kv,
|
||||
.stride_kv = stride_kv,
|
||||
.num_heads = num_heads, .head_dim = head_dim,
|
||||
.num_q_stages = num_q_stages,
|
||||
.num_kv_stages = num_kv_stages,
|
||||
.block_q = block_q,
|
||||
.block_kv = block_kv,
|
||||
.cu_seq_len_k_start = cu_seq_len_k_start.data_ptr<int>(),
|
||||
.cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
|
||||
.logits = logits.data_ptr<float>(),
|
||||
.tensor_map_q = tensor_map_q,
|
||||
.tensor_map_kv = tensor_map_kv,
|
||||
.tensor_map_kv_scales = tensor_map_kv_scales,
|
||||
.tensor_map_weights = tensor_map_weights,
|
||||
.num_specialized_threads = num_specialized_threads,
|
||||
.num_math_threads = num_math_threads,
|
||||
.launch_args = LaunchArgs(device_runtime->get_num_sms(),
|
||||
num_specialized_threads + num_math_threads,
|
||||
smem_size)
|
||||
};
|
||||
const auto& code = SM90FP8MQALogitsRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_fp8_mqa_logits", code);
|
||||
SM90FP8MQALogitsRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
228
csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp
Normal file
228
csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp
Normal file
@@ -0,0 +1,228 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime<SMXXPagedMQALogitsMetadataRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int aligned_batch_size;
|
||||
int split_kv;
|
||||
int num_sms;
|
||||
|
||||
int batch_size;
|
||||
int* context_lens;
|
||||
int* schedule_metadata;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
const auto& arch = device_runtime->get_arch(true);
|
||||
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm{}_fp8_paged_mqa_logits.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&smxx_paged_mqa_logits_metadata<
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)", arch, args.aligned_batch_size, args.split_kv, args.num_sms);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.batch_size,
|
||||
args.context_lens,
|
||||
args.schedule_metadata
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
|
||||
const torch::Tensor& schedule_metadata,
|
||||
const int& batch_size, const int& block_kv, const int& num_sms) {
|
||||
constexpr int num_math_warpgroups = 4;
|
||||
constexpr int num_threads = 32;
|
||||
const int aligned_batch_size = align(batch_size, 32);
|
||||
const int split_kv = block_kv * num_math_warpgroups;
|
||||
|
||||
// Calculate shared memory size
|
||||
const int smem_size = aligned_batch_size * static_cast<int>(sizeof(int));
|
||||
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
||||
|
||||
// Launch
|
||||
const SMXXPagedMQALogitsMetadataRuntime::Args& args = {
|
||||
.aligned_batch_size = aligned_batch_size,
|
||||
.split_kv = split_kv,
|
||||
.num_sms = num_sms,
|
||||
.batch_size = batch_size,
|
||||
.context_lens = context_lens.data_ptr<int>(),
|
||||
.schedule_metadata = schedule_metadata.data_ptr<int>(),
|
||||
.launch_args = LaunchArgs(1, num_threads, smem_size)
|
||||
};
|
||||
const auto& code = SMXXPagedMQALogitsMetadataRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("smxx_paged_mqa_logits_metadata", code);
|
||||
SMXXPagedMQALogitsMetadataRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime<SMXXFP8PagedMQALogitsRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int batch_size;
|
||||
int next_n;
|
||||
int num_heads;
|
||||
int head_dim;
|
||||
int block_kv;
|
||||
int block_table_stride;
|
||||
int logits_stride;
|
||||
|
||||
int num_q_stages;
|
||||
int num_kv_stages;
|
||||
int split_kv;
|
||||
|
||||
int* context_lens;
|
||||
float* logits;
|
||||
int* block_table;
|
||||
int* schedule_meta;
|
||||
|
||||
CUtensorMap tensor_map_q;
|
||||
CUtensorMap tensor_map_kv;
|
||||
CUtensorMap tensor_map_kv_scales;
|
||||
CUtensorMap tensor_map_weights;
|
||||
|
||||
int num_specialized_threads;
|
||||
int num_math_threads;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
// TODO: optimize performance by tuning args
|
||||
// Block sizes are fixed in this kernel
|
||||
DG_HOST_ASSERT(128 % args.num_heads == 0);
|
||||
const auto& arch = device_runtime->get_arch(true);
|
||||
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm{}_fp8_paged_mqa_logits.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_paged_mqa_logits<
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)", arch, arch,
|
||||
args.next_n, args.num_heads,
|
||||
args.head_dim, args.block_kv,
|
||||
args.num_q_stages, args.num_kv_stages,
|
||||
args.split_kv,
|
||||
args.num_specialized_threads, args.num_math_threads);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.batch_size,
|
||||
static_cast<uint64_t>(args.logits_stride),
|
||||
static_cast<uint64_t>(args.block_table_stride),
|
||||
args.context_lens, args.logits,
|
||||
args.block_table, args.schedule_meta,
|
||||
args.tensor_map_q, args.tensor_map_kv,
|
||||
args.tensor_map_kv_scales, args.tensor_map_weights
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
const torch::Tensor& kv_cache,
|
||||
const torch::Tensor& kv_cache_scales,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& context_lens,
|
||||
const torch::Tensor& logits,
|
||||
const torch::Tensor& block_table,
|
||||
const torch::Tensor& schedule_meta,
|
||||
const int& batch_size, const int& next_n,
|
||||
const int& num_heads, const int& head_dim,
|
||||
const int& num_kv_blocks, const int& block_kv,
|
||||
const int& kv_cache_stride_bytes,
|
||||
const int& logits_stride,
|
||||
const int& block_table_stride,
|
||||
const int& num_sms,
|
||||
const int& num_math_warp_groups) {
|
||||
const int num_specialized_threads = 128;
|
||||
const int num_math_threads = num_math_warp_groups * 128;
|
||||
const int num_q_stages = 3, num_kv_stages = 3;
|
||||
const int split_kv = num_math_warp_groups * block_kv;
|
||||
DG_HOST_ASSERT(logits_stride % (num_math_warp_groups * block_kv) == 0);
|
||||
|
||||
// Construct TMAs
|
||||
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
|
||||
const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads,
|
||||
head_dim, next_n * num_heads, head_dim, head_dim);
|
||||
const auto& tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks,
|
||||
head_dim, block_kv, 1,
|
||||
head_dim, kv_cache_stride_bytes, head_dim);
|
||||
// TODO: use 1D TMA
|
||||
const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks,
|
||||
block_kv, 1, kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 0);
|
||||
const auto& tensor_map_weights = make_tma_2d_desc(weights, next_n * num_heads, batch_size,
|
||||
next_n * num_heads, 1, next_n * num_heads, 0);
|
||||
|
||||
// Calculate shared memory size
|
||||
const int swizzle_alignment = head_dim * 8;
|
||||
const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
|
||||
const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
|
||||
const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment);
|
||||
|
||||
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
|
||||
const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast<int>(kv_cache_scales.element_size()), swizzle_alignment);
|
||||
const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment);
|
||||
|
||||
const int smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size;
|
||||
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
||||
|
||||
// Launch
|
||||
const SMXXFP8PagedMQALogitsRuntime::Args& args = {
|
||||
.batch_size = batch_size,
|
||||
.next_n = next_n,
|
||||
.num_heads = num_heads,
|
||||
.head_dim = head_dim,
|
||||
.block_kv = block_kv,
|
||||
.block_table_stride = block_table_stride,
|
||||
.logits_stride = logits_stride,
|
||||
.num_q_stages = num_q_stages,
|
||||
.num_kv_stages = num_kv_stages,
|
||||
.split_kv = split_kv,
|
||||
.context_lens = context_lens.data_ptr<int>(),
|
||||
.logits = logits.data_ptr<float>(),
|
||||
.block_table = block_table.data_ptr<int>(),
|
||||
.schedule_meta = schedule_meta.data_ptr<int>(),
|
||||
.tensor_map_q = tensor_map_q,
|
||||
.tensor_map_kv = tensor_map_kv,
|
||||
.tensor_map_kv_scales = tensor_map_kv_scales,
|
||||
.tensor_map_weights = tensor_map_weights,
|
||||
.num_specialized_threads = num_specialized_threads,
|
||||
.num_math_threads = num_math_threads,
|
||||
.launch_args = LaunchArgs(num_sms,
|
||||
num_specialized_threads + num_math_threads,
|
||||
smem_size)
|
||||
};
|
||||
const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_fp8_paged_mqa_logits", code);
|
||||
SMXXFP8PagedMQALogitsRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
Reference in New Issue
Block a user