Add SM90 kernels (#200)
This commit is contained in:
@@ -4,6 +4,9 @@
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/smxx_fp8_mqa_logits.hpp"
|
||||
#include "../jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp"
|
||||
#include "../jit_kernels/impls/smxx_clean_logits.hpp"
|
||||
|
||||
#include "layout.hpp"
|
||||
|
||||
@@ -66,12 +69,172 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
|
||||
}
|
||||
}
|
||||
|
||||
static torch::Tensor fp8_mqa_logits(const torch::Tensor& q,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& kv,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& cu_seq_len_k_start,
|
||||
const torch::Tensor& cu_seq_len_k_end,
|
||||
const bool& clean_logits) {
|
||||
const auto& [seq_len, num_heads, head_dim] = get_shape<3>(q);
|
||||
const auto& [seq_len_kv, head_dim_] = get_shape<2>(kv.first);
|
||||
const auto& [seq_len_, num_heads_] = get_shape<2>(weights);
|
||||
const auto& [seq_len_kv_] = get_shape<1>(kv.second);
|
||||
|
||||
DG_HOST_ASSERT(seq_len == seq_len_);
|
||||
DG_HOST_ASSERT(num_heads == num_heads_ and head_dim == head_dim_);
|
||||
DG_HOST_ASSERT(seq_len_kv == seq_len_kv_);
|
||||
DG_HOST_ASSERT(cu_seq_len_k_start.size(0) == seq_len);
|
||||
DG_HOST_ASSERT(cu_seq_len_k_end.size(0) == seq_len);
|
||||
|
||||
DG_HOST_ASSERT(q.is_contiguous() and kv.first.is_contiguous());
|
||||
DG_HOST_ASSERT(kv.second.is_contiguous());
|
||||
DG_HOST_ASSERT(weights.is_contiguous());
|
||||
DG_HOST_ASSERT(cu_seq_len_k_start.is_contiguous());
|
||||
DG_HOST_ASSERT(cu_seq_len_k_end.is_contiguous());
|
||||
|
||||
DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(kv.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(kv.second.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(cu_seq_len_k_start.scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(cu_seq_len_k_end.scalar_type() == torch::kInt);
|
||||
|
||||
constexpr int seq_len_alignment = 4;
|
||||
constexpr int block_kv = 256;
|
||||
const auto aligned_seq_len = align(seq_len, seq_len_alignment);
|
||||
const auto aligned_seq_len_kv = align(seq_len_kv + block_kv, 4);
|
||||
auto logits = torch::empty({aligned_seq_len, aligned_seq_len_kv}, q.options().dtype(torch::kFloat));
|
||||
logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, seq_len_kv)});
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits,
|
||||
seq_len, seq_len_kv, aligned_seq_len_kv, num_heads, head_dim, seq_len_alignment);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
|
||||
// Clean unfilled logits
|
||||
if (clean_logits)
|
||||
smxx_clean_logits(logits, cu_seq_len_k_start, cu_seq_len_k_end, 1, seq_len, seq_len_kv, aligned_seq_len_kv);
|
||||
return logits;
|
||||
}
|
||||
|
||||
static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms) {
|
||||
const auto& [batch_size] = get_shape<1>(context_lens);
|
||||
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(context_lens.is_contiguous());
|
||||
|
||||
auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options());
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, block_kv, num_sms);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
|
||||
return schedule_metadata;
|
||||
}
|
||||
|
||||
static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
const torch::Tensor& fused_kv_cache,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& context_lens,
|
||||
const torch::Tensor& block_table,
|
||||
const torch::Tensor& schedule_meta,
|
||||
const int& max_context_len,
|
||||
const bool& clean_logits) {
|
||||
const auto& [batch_size, next_n, num_heads, head_dim] = get_shape<4>(q);
|
||||
const auto& [num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf] = get_shape<4>(fused_kv_cache);
|
||||
const auto& [batch_size_] = get_shape<1>(context_lens);
|
||||
const auto& [batch_size_next_n, num_heads_] = get_shape<2>(weights);
|
||||
const auto& [batch_size__, max_block_len] = get_shape<2>(block_table);
|
||||
const auto& [schedule_meta_size, meta_info_size] = get_shape<2>(schedule_meta);
|
||||
const auto& num_sms = device_runtime->get_num_sms();
|
||||
const auto& kv_cache_stride_bytes = fused_kv_cache.stride(0);
|
||||
const auto& block_table_stride = block_table.stride(0);
|
||||
|
||||
DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size__);
|
||||
DG_HOST_ASSERT(batch_size_next_n == batch_size * next_n);
|
||||
DG_HOST_ASSERT(num_heads == num_heads_ and num_heads_kv == 1);
|
||||
DG_HOST_ASSERT(head_dim_with_sf == head_dim + static_cast<int>(sizeof(float)));
|
||||
DG_HOST_ASSERT(schedule_meta_size == num_sms + 1 and meta_info_size == 2);
|
||||
|
||||
DG_HOST_ASSERT(next_n == 1 or next_n == 2);
|
||||
DG_HOST_ASSERT(block_kv == 64);
|
||||
|
||||
DG_HOST_ASSERT(q.is_contiguous());
|
||||
DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(float) == 0);
|
||||
DG_HOST_ASSERT(fused_kv_cache.stride(1) == head_dim_with_sf);
|
||||
DG_HOST_ASSERT(fused_kv_cache.stride(2) == head_dim_with_sf);
|
||||
DG_HOST_ASSERT(fused_kv_cache.stride(3) == 1);
|
||||
DG_HOST_ASSERT(weights.is_contiguous());
|
||||
DG_HOST_ASSERT(context_lens.is_contiguous());
|
||||
DG_HOST_ASSERT(block_table.stride(1) == 1);
|
||||
DG_HOST_ASSERT(schedule_meta.is_contiguous());
|
||||
|
||||
DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte);
|
||||
DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(block_table.scalar_type() == torch::kInt);
|
||||
DG_HOST_ASSERT(schedule_meta.scalar_type() == torch::kInt);
|
||||
|
||||
// Derive FP8 values and SF tensor from KV cache
|
||||
const auto& kv_cache = torch::from_blob(
|
||||
fused_kv_cache.data_ptr(),
|
||||
{num_kv_blocks, block_kv, head_dim},
|
||||
{kv_cache_stride_bytes, head_dim, 1},
|
||||
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn)
|
||||
);
|
||||
const auto& kv_cache_scales = torch::from_blob(
|
||||
fused_kv_cache.data_ptr<uint8_t>() + block_kv * head_dim,
|
||||
{num_kv_blocks, block_kv},
|
||||
{kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 1},
|
||||
torch::TensorOptions().dtype(torch::kFloat32)
|
||||
);
|
||||
|
||||
// Allocate output
|
||||
constexpr int num_math_warp_groups = 4;
|
||||
const auto& aligned_max_context_len = align(max_context_len, num_math_warp_groups * block_kv);
|
||||
auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q.options().dtype(torch::kFloat));
|
||||
logits = logits.slice(-1, 0, max_context_len);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta,
|
||||
batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv,
|
||||
kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, num_math_warp_groups);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
|
||||
// Clean unfilled logits
|
||||
if (clean_logits)
|
||||
smxx_clean_logits(logits, std::nullopt, context_lens, next_n, batch_size * next_n, max_context_len, aligned_max_context_len);
|
||||
return logits;
|
||||
}
|
||||
|
||||
static void register_apis(pybind11::module_& m) {
|
||||
m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("head_splits"),
|
||||
py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_mqa_logits", &fp8_mqa_logits,
|
||||
py::arg("q"), py::arg("kv"), py::arg("weights"),
|
||||
py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"),
|
||||
py::arg("clean_logits") = true);
|
||||
m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata,
|
||||
py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"));
|
||||
m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits,
|
||||
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
|
||||
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),
|
||||
py::arg("max_context_len"), py::arg("clean_logits") = false);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm::attention
|
||||
|
||||
@@ -6,6 +6,11 @@
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
|
||||
|
||||
// Attention kernels
|
||||
#include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
|
||||
#include <deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh>
|
||||
#include <deep_gemm/impls/smxx_clean_logits.cuh>
|
||||
|
||||
// Einsum kernels
|
||||
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
|
||||
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
|
||||
|
||||
79
csrc/jit_kernels/impls/smxx_clean_logits.hpp
Normal file
79
csrc/jit_kernels/impls/smxx_clean_logits.hpp
Normal file
@@ -0,0 +1,79 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SMXXCleanLogitsRuntime final: public LaunchRuntime<SMXXCleanLogitsRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int next_n;
|
||||
int seq_len;
|
||||
int seq_len_kv;
|
||||
uint64_t stride_kv;
|
||||
|
||||
int* cu_seq_len_k_start;
|
||||
int* cu_seq_len_k_end;
|
||||
float* logits;
|
||||
|
||||
int block_kv;
|
||||
int num_warps;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/smxx_clean_logits.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&smxx_clean_logits<
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)", args.next_n, args.block_kv, args.num_warps);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.seq_len, args.seq_len_kv, static_cast<int64_t>(args.stride_kv),
|
||||
args.cu_seq_len_k_start, args.cu_seq_len_k_end, args.logits
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
static void smxx_clean_logits(const torch::Tensor& logits,
|
||||
const std::optional<torch::Tensor>& cu_seq_len_k_start,
|
||||
const torch::Tensor& cu_seq_len_k_end,
|
||||
const int& next_n,
|
||||
const int& seq_len, const int& seq_len_kv,
|
||||
const uint64_t &stride_kv) {
|
||||
const int block_kv = 8192;
|
||||
const int num_warps = 8;
|
||||
const int smem_size = block_kv * sizeof(float);
|
||||
|
||||
// Launch
|
||||
const SMXXCleanLogitsRuntime::Args& args = {
|
||||
.next_n = next_n,
|
||||
.seq_len = seq_len,
|
||||
.seq_len_kv = seq_len_kv,
|
||||
.stride_kv = stride_kv,
|
||||
.cu_seq_len_k_start = cu_seq_len_k_start.has_value() ? cu_seq_len_k_start.value().data_ptr<int>() : nullptr,
|
||||
.cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
|
||||
.logits = logits.data_ptr<float>(),
|
||||
.block_kv = block_kv,
|
||||
.num_warps = num_warps,
|
||||
.launch_args = LaunchArgs(device_runtime->get_num_sms(),
|
||||
num_warps * 32, smem_size)
|
||||
};
|
||||
const auto& code = SMXXCleanLogitsRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("smxx_clean_logits", code);
|
||||
SMXXCleanLogitsRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
152
csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp
Normal file
152
csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp
Normal file
@@ -0,0 +1,152 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM90FP8MQALogitsRuntime final: public LaunchRuntime<SM90FP8MQALogitsRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int seq_len;
|
||||
int seq_len_kv;
|
||||
int stride_kv;
|
||||
int num_heads, head_dim;
|
||||
int num_q_stages;
|
||||
int num_kv_stages;
|
||||
|
||||
int block_q;
|
||||
int block_kv;
|
||||
|
||||
int* cu_seq_len_k_start;
|
||||
int* cu_seq_len_k_end;
|
||||
float* logits;
|
||||
float softmax_scale;
|
||||
|
||||
CUtensorMap tensor_map_q;
|
||||
CUtensorMap tensor_map_kv;
|
||||
CUtensorMap tensor_map_kv_scales;
|
||||
CUtensorMap tensor_map_weights;
|
||||
|
||||
int num_specialized_threads;
|
||||
int num_math_threads;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
// TODO: optimize performance by tuning args
|
||||
// Block sizes are fixed in this kernel
|
||||
DG_HOST_ASSERT(128 % args.num_heads == 0);
|
||||
const auto& arch = device_runtime->get_arch(true);
|
||||
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm{}_fp8_mqa_logits.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_mqa_logits<
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)", arch, arch,
|
||||
args.num_heads, args.head_dim,
|
||||
args.block_q, args.block_kv,
|
||||
args.num_q_stages, args.num_kv_stages,
|
||||
args.num_specialized_threads, args.num_math_threads);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.seq_len, args.seq_len_kv, static_cast<int64_t>(args.stride_kv),
|
||||
args.cu_seq_len_k_start, args.cu_seq_len_k_end,
|
||||
args.logits,
|
||||
args.tensor_map_q, args.tensor_map_kv,
|
||||
args.tensor_map_kv_scales, args.tensor_map_weights
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
static void smxx_fp8_mqa_logits(const torch::Tensor& q,
|
||||
const torch::Tensor& kv, const torch::Tensor& kv_scales,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& cu_seq_len_k_start,
|
||||
const torch::Tensor& cu_seq_len_k_end,
|
||||
const torch::Tensor& logits,
|
||||
const int& seq_len, const int& seq_len_kv, const int& stride_kv,
|
||||
const int& num_heads, const int& head_dim,
|
||||
const int& seq_len_alignment) {
|
||||
constexpr int block_qh = 128;
|
||||
constexpr int block_kv = 256;
|
||||
constexpr int num_specialized_threads = 128;
|
||||
constexpr int num_math_threads = 512;
|
||||
constexpr int num_q_stages = 3, num_kv_stages = 3;
|
||||
const int block_q = block_qh / num_heads;
|
||||
DG_HOST_ASSERT(block_qh % num_heads == 0);
|
||||
DG_HOST_ASSERT(seq_len_alignment % block_q == 0);
|
||||
|
||||
// Construct TMAs
|
||||
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
|
||||
const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads,
|
||||
head_dim, block_qh, head_dim, head_dim);
|
||||
const auto& tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv,
|
||||
head_dim, block_kv, head_dim, head_dim);
|
||||
// According to the driver API, the minimal alignment is 256 bytes
|
||||
// So it is safe for us to do a 16-byte OOB
|
||||
const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_scales,
|
||||
get_tma_aligned_size(seq_len_kv, static_cast<int>(kv_scales.element_size())),
|
||||
1, block_kv, 1, 0, 0);
|
||||
const auto& tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len,
|
||||
num_heads, block_q, num_heads, 0);
|
||||
|
||||
// Calculate shared memory size
|
||||
int smem_size = 0;
|
||||
const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast<int>(q.element_size());
|
||||
const int smem_weight_size_per_stage = block_q * num_heads * static_cast<int>(weights.element_size());
|
||||
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv.element_size());
|
||||
const int kv_scale_size_per_stage = block_kv * static_cast<int>(kv_scales.element_size());
|
||||
smem_size += num_q_stages * smem_q_size_per_stage;
|
||||
smem_size += num_kv_stages * smem_kv_size_per_stage;
|
||||
smem_size += num_q_stages * smem_weight_size_per_stage;
|
||||
smem_size += num_kv_stages * kv_scale_size_per_stage;
|
||||
smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8;
|
||||
smem_size += 4;
|
||||
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
||||
|
||||
// Launch
|
||||
const SM90FP8MQALogitsRuntime::Args& args = {
|
||||
.seq_len = seq_len,
|
||||
.seq_len_kv = seq_len_kv,
|
||||
.stride_kv = stride_kv,
|
||||
.num_heads = num_heads, .head_dim = head_dim,
|
||||
.num_q_stages = num_q_stages,
|
||||
.num_kv_stages = num_kv_stages,
|
||||
.block_q = block_q,
|
||||
.block_kv = block_kv,
|
||||
.cu_seq_len_k_start = cu_seq_len_k_start.data_ptr<int>(),
|
||||
.cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
|
||||
.logits = logits.data_ptr<float>(),
|
||||
.tensor_map_q = tensor_map_q,
|
||||
.tensor_map_kv = tensor_map_kv,
|
||||
.tensor_map_kv_scales = tensor_map_kv_scales,
|
||||
.tensor_map_weights = tensor_map_weights,
|
||||
.num_specialized_threads = num_specialized_threads,
|
||||
.num_math_threads = num_math_threads,
|
||||
.launch_args = LaunchArgs(device_runtime->get_num_sms(),
|
||||
num_specialized_threads + num_math_threads,
|
||||
smem_size)
|
||||
};
|
||||
const auto& code = SM90FP8MQALogitsRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_fp8_mqa_logits", code);
|
||||
SM90FP8MQALogitsRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
228
csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp
Normal file
228
csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp
Normal file
@@ -0,0 +1,228 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime<SMXXPagedMQALogitsMetadataRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int aligned_batch_size;
|
||||
int split_kv;
|
||||
int num_sms;
|
||||
|
||||
int batch_size;
|
||||
int* context_lens;
|
||||
int* schedule_metadata;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
const auto& arch = device_runtime->get_arch(true);
|
||||
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm{}_fp8_paged_mqa_logits.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&smxx_paged_mqa_logits_metadata<
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)", arch, args.aligned_batch_size, args.split_kv, args.num_sms);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.batch_size,
|
||||
args.context_lens,
|
||||
args.schedule_metadata
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
|
||||
const torch::Tensor& schedule_metadata,
|
||||
const int& batch_size, const int& block_kv, const int& num_sms) {
|
||||
constexpr int num_math_warpgroups = 4;
|
||||
constexpr int num_threads = 32;
|
||||
const int aligned_batch_size = align(batch_size, 32);
|
||||
const int split_kv = block_kv * num_math_warpgroups;
|
||||
|
||||
// Calculate shared memory size
|
||||
const int smem_size = aligned_batch_size * static_cast<int>(sizeof(int));
|
||||
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
||||
|
||||
// Launch
|
||||
const SMXXPagedMQALogitsMetadataRuntime::Args& args = {
|
||||
.aligned_batch_size = aligned_batch_size,
|
||||
.split_kv = split_kv,
|
||||
.num_sms = num_sms,
|
||||
.batch_size = batch_size,
|
||||
.context_lens = context_lens.data_ptr<int>(),
|
||||
.schedule_metadata = schedule_metadata.data_ptr<int>(),
|
||||
.launch_args = LaunchArgs(1, num_threads, smem_size)
|
||||
};
|
||||
const auto& code = SMXXPagedMQALogitsMetadataRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("smxx_paged_mqa_logits_metadata", code);
|
||||
SMXXPagedMQALogitsMetadataRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime<SMXXFP8PagedMQALogitsRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int batch_size;
|
||||
int next_n;
|
||||
int num_heads;
|
||||
int head_dim;
|
||||
int block_kv;
|
||||
int block_table_stride;
|
||||
int logits_stride;
|
||||
|
||||
int num_q_stages;
|
||||
int num_kv_stages;
|
||||
int split_kv;
|
||||
|
||||
int* context_lens;
|
||||
float* logits;
|
||||
int* block_table;
|
||||
int* schedule_meta;
|
||||
|
||||
CUtensorMap tensor_map_q;
|
||||
CUtensorMap tensor_map_kv;
|
||||
CUtensorMap tensor_map_kv_scales;
|
||||
CUtensorMap tensor_map_weights;
|
||||
|
||||
int num_specialized_threads;
|
||||
int num_math_threads;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
// TODO: optimize performance by tuning args
|
||||
// Block sizes are fixed in this kernel
|
||||
DG_HOST_ASSERT(128 % args.num_heads == 0);
|
||||
const auto& arch = device_runtime->get_arch(true);
|
||||
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm{}_fp8_paged_mqa_logits.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_paged_mqa_logits<
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)", arch, arch,
|
||||
args.next_n, args.num_heads,
|
||||
args.head_dim, args.block_kv,
|
||||
args.num_q_stages, args.num_kv_stages,
|
||||
args.split_kv,
|
||||
args.num_specialized_threads, args.num_math_threads);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.batch_size,
|
||||
static_cast<uint64_t>(args.logits_stride),
|
||||
static_cast<uint64_t>(args.block_table_stride),
|
||||
args.context_lens, args.logits,
|
||||
args.block_table, args.schedule_meta,
|
||||
args.tensor_map_q, args.tensor_map_kv,
|
||||
args.tensor_map_kv_scales, args.tensor_map_weights
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
|
||||
const torch::Tensor& kv_cache,
|
||||
const torch::Tensor& kv_cache_scales,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& context_lens,
|
||||
const torch::Tensor& logits,
|
||||
const torch::Tensor& block_table,
|
||||
const torch::Tensor& schedule_meta,
|
||||
const int& batch_size, const int& next_n,
|
||||
const int& num_heads, const int& head_dim,
|
||||
const int& num_kv_blocks, const int& block_kv,
|
||||
const int& kv_cache_stride_bytes,
|
||||
const int& logits_stride,
|
||||
const int& block_table_stride,
|
||||
const int& num_sms,
|
||||
const int& num_math_warp_groups) {
|
||||
const int num_specialized_threads = 128;
|
||||
const int num_math_threads = num_math_warp_groups * 128;
|
||||
const int num_q_stages = 3, num_kv_stages = 3;
|
||||
const int split_kv = num_math_warp_groups * block_kv;
|
||||
DG_HOST_ASSERT(logits_stride % (num_math_warp_groups * block_kv) == 0);
|
||||
|
||||
// Construct TMAs
|
||||
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
|
||||
const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads,
|
||||
head_dim, next_n * num_heads, head_dim, head_dim);
|
||||
const auto& tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks,
|
||||
head_dim, block_kv, 1,
|
||||
head_dim, kv_cache_stride_bytes, head_dim);
|
||||
// TODO: use 1D TMA
|
||||
const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks,
|
||||
block_kv, 1, kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 0);
|
||||
const auto& tensor_map_weights = make_tma_2d_desc(weights, next_n * num_heads, batch_size,
|
||||
next_n * num_heads, 1, next_n * num_heads, 0);
|
||||
|
||||
// Calculate shared memory size
|
||||
const int swizzle_alignment = head_dim * 8;
|
||||
const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
|
||||
const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
|
||||
const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment);
|
||||
|
||||
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
|
||||
const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast<int>(kv_cache_scales.element_size()), swizzle_alignment);
|
||||
const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment);
|
||||
|
||||
const int smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size;
|
||||
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
|
||||
|
||||
// Launch
|
||||
const SMXXFP8PagedMQALogitsRuntime::Args& args = {
|
||||
.batch_size = batch_size,
|
||||
.next_n = next_n,
|
||||
.num_heads = num_heads,
|
||||
.head_dim = head_dim,
|
||||
.block_kv = block_kv,
|
||||
.block_table_stride = block_table_stride,
|
||||
.logits_stride = logits_stride,
|
||||
.num_q_stages = num_q_stages,
|
||||
.num_kv_stages = num_kv_stages,
|
||||
.split_kv = split_kv,
|
||||
.context_lens = context_lens.data_ptr<int>(),
|
||||
.logits = logits.data_ptr<float>(),
|
||||
.block_table = block_table.data_ptr<int>(),
|
||||
.schedule_meta = schedule_meta.data_ptr<int>(),
|
||||
.tensor_map_q = tensor_map_q,
|
||||
.tensor_map_kv = tensor_map_kv,
|
||||
.tensor_map_kv_scales = tensor_map_kv_scales,
|
||||
.tensor_map_weights = tensor_map_weights,
|
||||
.num_specialized_threads = num_specialized_threads,
|
||||
.num_math_threads = num_math_threads,
|
||||
.launch_args = LaunchArgs(num_sms,
|
||||
num_specialized_threads + num_math_threads,
|
||||
smem_size)
|
||||
};
|
||||
const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_fp8_paged_mqa_logits", code);
|
||||
SMXXFP8PagedMQALogitsRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -41,6 +41,10 @@ from deep_gemm_cpp import (
|
||||
cublaslt_gemm_tn, cublaslt_gemm_tt,
|
||||
# Einsum kernels
|
||||
einsum,
|
||||
# Attention kernels
|
||||
fp8_mqa_logits,
|
||||
get_paged_mqa_logits_metadata,
|
||||
fp8_paged_mqa_logits,
|
||||
# Layout kernels
|
||||
transform_sf_into_required_layout
|
||||
)
|
||||
|
||||
317
deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh
Normal file
317
deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh
Normal file
@@ -0,0 +1,317 @@
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/mma_sm90_desc.hpp>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm90;
|
||||
|
||||
// ReSharper disable once CppNotAllPathsReturnValue
|
||||
template <uint32_t kHeadDim>
|
||||
static constexpr int to_swizzle_cute_type() {
|
||||
DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
|
||||
if constexpr (kHeadDim == 32)
|
||||
return static_cast<int>(cute::SM90::GMMA::LayoutType::B32);
|
||||
if constexpr (kHeadDim == 64)
|
||||
return static_cast<int>(cute::SM90::GMMA::LayoutType::B64);
|
||||
if constexpr (kHeadDim == 128)
|
||||
return static_cast<int>(cute::SM90::GMMA::LayoutType::B128);
|
||||
}
|
||||
|
||||
template <uint32_t kNumHeads, uint32_t kHeadDim,
|
||||
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
|
||||
uint32_t kNumQStages, uint32_t kNumKVStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
|
||||
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
|
||||
void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_kv,
|
||||
uint32_t* cu_seq_len_k_start,
|
||||
uint32_t* cu_seq_len_k_end,
|
||||
float* logits,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
||||
// TODO: consider TMA multicast
|
||||
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
|
||||
// Q should be load only at once for a block
|
||||
const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q);
|
||||
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<BLOCK_Q * kNumHeads>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// Prefetch TMA descriptors
|
||||
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
||||
if (threadIdx.x / 32 == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_q);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Shared memory configs
|
||||
// NOTES: weight may be unaligned
|
||||
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
||||
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
|
||||
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
|
||||
|
||||
// Align to swizzling alignment bytes
|
||||
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
||||
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
|
||||
SMEM_Q_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
|
||||
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
|
||||
});
|
||||
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer +
|
||||
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer +
|
||||
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages +
|
||||
SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
||||
});
|
||||
|
||||
// TMA barriers
|
||||
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
||||
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
|
||||
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
|
||||
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
|
||||
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
|
||||
|
||||
// Initialize barriers
|
||||
const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
|
||||
if (is_tma_load_warp and cute::elect_one_sync()) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
||||
full_q_barriers[i]->init(1);
|
||||
empty_q_barriers[i]->init(kNumMathThreads);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
||||
full_kv_barriers[i]->init(1);
|
||||
empty_kv_barriers[i]->init(kNumMathThreads);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr uint32_t kNumTMARegisters = 32;
|
||||
constexpr uint32_t kNumMathRegisters = 112;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
|
||||
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
|
||||
return {block_q_idx + gridDim.x, q_iter_idx + 1};
|
||||
};
|
||||
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
|
||||
uint32_t start = cute::numeric_limits<uint32_t>::max();
|
||||
uint32_t end = cute::numeric_limits<uint32_t>::min();
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
|
||||
start = min(start, min(__ldg(cu_seq_len_k_start + q_idx), seq_len_kv));
|
||||
end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv));
|
||||
}
|
||||
start = start / 4 * 4;
|
||||
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
|
||||
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
|
||||
start, ceil_div(end - start, BLOCK_KV)}; // Task info
|
||||
};
|
||||
|
||||
// KV pipeline
|
||||
uint32_t num_total_kv_blocks = 0;
|
||||
const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
|
||||
return {
|
||||
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
|
||||
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
|
||||
};
|
||||
};
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// Only the first warp remains
|
||||
if (not is_tma_load_warp)
|
||||
return;
|
||||
|
||||
// Prefetch
|
||||
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
|
||||
tma_copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
|
||||
tma_copy(&tensor_map_weights, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
|
||||
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
||||
};
|
||||
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
|
||||
issue_tma_q(0, block_q_idx);
|
||||
|
||||
// Only the first lane persistently schedules over blocks
|
||||
if (cute::elect_one_sync()) {
|
||||
while (block_q_idx < num_q_blocks) {
|
||||
CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
||||
|
||||
// Wait Q consumer release
|
||||
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
||||
|
||||
// Issue TMA Q
|
||||
if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks)
|
||||
issue_tma_q(q_stage_idx, next_block_q_idx);
|
||||
|
||||
// Issue TMA KV
|
||||
#pragma unroll
|
||||
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
|
||||
// Wait consumer release
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
|
||||
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
||||
|
||||
// Issue TMA KV
|
||||
tma_copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
|
||||
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
|
||||
tma_copy(&tensor_map_kv_scales, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
|
||||
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
|
||||
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
||||
}
|
||||
num_total_kv_blocks += num_kv_blocks;
|
||||
|
||||
// Jump to the next block
|
||||
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto& thread_idx = threadIdx.x % kNumMathThreads;
|
||||
const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0);
|
||||
const auto& warpgroup_idx = warp_idx / 4;
|
||||
const auto& lane_idx = get_lane_idx();
|
||||
float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4];
|
||||
|
||||
const auto& warp_offset = warp_idx * 16;
|
||||
const auto& v_0_offset = lane_idx / 4 + 0;
|
||||
const auto& v_1_offset = lane_idx / 4 + 8;
|
||||
|
||||
while (block_q_idx < num_q_blocks) {
|
||||
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
|
||||
|
||||
// Wait TMA Q arrival
|
||||
full_q_barriers[q_stage_idx]->wait(q_phase);
|
||||
|
||||
// Read weights
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
|
||||
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
|
||||
}
|
||||
|
||||
// Compute over KV blocks
|
||||
#pragma unroll
|
||||
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
|
||||
// Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]`
|
||||
// Wait TMA KV arrival
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
|
||||
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
||||
|
||||
// Read per-KV scales
|
||||
float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
|
||||
float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
|
||||
|
||||
// Issue WGMMA
|
||||
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
|
||||
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
|
||||
to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
||||
auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K,
|
||||
to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Release KV empty
|
||||
empty_kv_barriers[kv_stage_idx]->arrive();
|
||||
|
||||
// Reduce over the head dim and store
|
||||
const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset;
|
||||
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation");
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == BLOCK_Q, "Invalid accumulation");
|
||||
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
|
||||
auto shifted_accum = accum + i * kNumAccumPerReduce;
|
||||
const auto& transform = [&](const uint32_t& j) {
|
||||
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
|
||||
};
|
||||
|
||||
// Intra-thread reduction
|
||||
float sum[4] = {transform(0), transform(1), transform(2), transform(3)};
|
||||
#pragma unroll
|
||||
for (uint32_t j = 1; j < kNumHeads / 8; ++ j) {
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < 4; k ++)
|
||||
sum[k] += transform(j * 4 + k);
|
||||
}
|
||||
float v_0 = (sum[0] + sum[1]) * scale_kv_0;
|
||||
float v_1 = (sum[2] + sum[3]) * scale_kv_1;
|
||||
|
||||
// Inter-thread reduction
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++ j) {
|
||||
const auto& offset = static_cast<int>(1u << j);
|
||||
v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
|
||||
v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
|
||||
}
|
||||
|
||||
// Store into the global memory
|
||||
// NOTES: we have redundant writes here, consider more carefully
|
||||
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
|
||||
logits[q_idx * stride_kv + kv_offset + v_0_offset] = v_0;
|
||||
logits[q_idx * stride_kv + kv_offset + v_1_offset] = v_1;
|
||||
}
|
||||
}
|
||||
num_total_kv_blocks += num_kv_blocks;
|
||||
|
||||
// Release Q empty
|
||||
empty_q_barriers[q_stage_idx]->arrive();
|
||||
|
||||
// Jump to the next block
|
||||
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
403
deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh
Normal file
403
deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh
Normal file
@@ -0,0 +1,403 @@
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
#include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
|
||||
__global__ __launch_bounds__(32, 1)
|
||||
void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t* context_lens, uint32_t* schedule_metadata) {
|
||||
DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
|
||||
const uint32_t lane_idx = get_lane_idx();
|
||||
|
||||
uint32_t num_segs[kAlignedBatchSize / 32];
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
|
||||
const uint32_t& context_len = (k * 32 + lane_idx < batch_size ? __ldg(context_lens + k * 32 + lane_idx) : 0);
|
||||
num_segs[k] = ceil_div(context_len, SPLIT_KV);
|
||||
}
|
||||
|
||||
__shared__ uint32_t prefix_sum[kAlignedBatchSize];
|
||||
uint32_t sum = 0;
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
|
||||
uint32_t x = num_segs[k];
|
||||
#pragma unroll
|
||||
for (uint32_t offset = 1; offset < 32; offset <<= 1) {
|
||||
const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset);
|
||||
x += (lane_idx >= offset ? y : 0);
|
||||
}
|
||||
x += sum;
|
||||
prefix_sum[k * 32 + lane_idx] = x;
|
||||
sum = __shfl_sync(0xffffffff, x, 31);
|
||||
}
|
||||
|
||||
const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs;
|
||||
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
|
||||
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
|
||||
uint32_t q_idx = 0;
|
||||
while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts)
|
||||
++ q_idx;
|
||||
const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]);
|
||||
__syncwarp();
|
||||
|
||||
schedule_metadata[sm_idx * 2] = q_idx;
|
||||
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
|
||||
}
|
||||
}
|
||||
|
||||
template <uint32_t BLOCK_KV, uint32_t kNumMathWarpGroups>
|
||||
struct PagedMQALogitsScheduler {
|
||||
uint32_t batch_size;
|
||||
const uint32_t* context_lens;
|
||||
|
||||
uint32_t current_q_idx, current_kv_idx;
|
||||
uint32_t end_q_idx, end_kv_idx;
|
||||
uint32_t current_num_kv;
|
||||
|
||||
__device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx,
|
||||
const uint32_t* context_lens, const uint32_t* schedule_meta) {
|
||||
this->batch_size = batch_size;
|
||||
this->context_lens = context_lens;
|
||||
|
||||
const auto& current_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx);
|
||||
const auto& end_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx + 1);
|
||||
current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumMathWarpGroups;
|
||||
end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumMathWarpGroups;
|
||||
|
||||
current_num_kv = current_q_idx < batch_size ? ceil_div(__ldg(this->context_lens + current_q_idx), BLOCK_KV) : 0;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) {
|
||||
q_idx = current_q_idx;
|
||||
kv_idx = current_kv_idx;
|
||||
num_kv = current_num_kv;
|
||||
|
||||
if (q_idx == end_q_idx and kv_idx == end_kv_idx)
|
||||
return false;
|
||||
|
||||
current_kv_idx += kNumMathWarpGroups;
|
||||
if (current_kv_idx >= current_num_kv) {
|
||||
++ current_q_idx;
|
||||
current_kv_idx = 0;
|
||||
current_num_kv = current_q_idx < batch_size ? ceil_div(__ldg(this->context_lens + current_q_idx), BLOCK_KV) : 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const {
|
||||
return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx;
|
||||
}
|
||||
};
|
||||
|
||||
using namespace deep_gemm::sm90;
|
||||
|
||||
template <uint32_t kNextN, uint32_t kNumHeads,
|
||||
uint32_t kHeadDim, uint32_t BLOCK_KV,
|
||||
uint32_t kNumQStages, uint32_t kNumKVStages,
|
||||
uint32_t SPLIT_KV,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
|
||||
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
|
||||
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
|
||||
const uint64_t logits_stride, const uint64_t block_table_stride,
|
||||
const uint32_t* context_lens, float* logits,
|
||||
const uint32_t* block_table, const uint32_t* schedule_meta,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
|
||||
// Types
|
||||
using WGMMA = typename FP8MMASelector<kNextN * kNumHeads>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const auto& warpgroup_idx = warp_idx / 4;
|
||||
const auto& lane_idx = get_lane_idx();
|
||||
|
||||
// Prefetch TMA descriptors
|
||||
static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
|
||||
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
|
||||
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_q);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_kv);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_weights);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Shared memory configs
|
||||
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
|
||||
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
|
||||
static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
|
||||
static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
|
||||
constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
|
||||
|
||||
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
|
||||
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
|
||||
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
|
||||
static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
|
||||
constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
|
||||
|
||||
// Align to swizzling alignment bytes
|
||||
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
||||
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
|
||||
|
||||
// Q data and barriers on shared memory
|
||||
auto smem_q = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
|
||||
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
|
||||
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
|
||||
|
||||
// Separate math warpgroups and tma load warps into KV groups
|
||||
// Each math warpgroup corresponds to a tma load warp
|
||||
const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);
|
||||
|
||||
// Per group KV data and barriers on shared memory
|
||||
const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
|
||||
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
|
||||
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
|
||||
});
|
||||
auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
|
||||
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
|
||||
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
|
||||
|
||||
// Initialize barriers
|
||||
if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
if (kv_group_idx == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumQStages; ++ i) {
|
||||
full_q_barriers[i]->init(1);
|
||||
empty_q_barriers[i]->init(kNumMathThreads);
|
||||
}
|
||||
}
|
||||
if (kv_group_idx < kNumMathWarpGroups) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
|
||||
full_kv_barriers[i]->init(1);
|
||||
empty_kv_barriers[i]->init(128);
|
||||
}
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr uint32_t kNumTMARegisters = 64;
|
||||
constexpr uint32_t kNumMathRegisters = 104;
|
||||
|
||||
// Scheduler
|
||||
auto scheduler = PagedMQALogitsScheduler<BLOCK_KV, kNumMathWarpGroups>(batch_size, blockIdx.x, context_lens, schedule_meta);
|
||||
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
|
||||
|
||||
// Q and KV pipeline
|
||||
const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
||||
return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
|
||||
};
|
||||
const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
|
||||
return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
|
||||
};
|
||||
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
|
||||
|
||||
if (warp_idx >= kNumMathThreads / 32) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
if (kv_group_idx >= kNumMathWarpGroups)
|
||||
return;
|
||||
|
||||
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
|
||||
if (kv_group_idx == 0 and cute::elect_one_sync()) {
|
||||
tma_copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
|
||||
tma_copy(&tensor_map_weights, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, q_idx);
|
||||
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
|
||||
uint32_t q_idx = batch_size, kv_idx, num_kv;
|
||||
uint32_t next_q_idx, next_kv_idx, next_num_kv;
|
||||
bool fetched_next_task;
|
||||
|
||||
// Prefetch the first Q
|
||||
if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)))
|
||||
issue_tma_q(0, next_q_idx), q_iter_idx = 1;
|
||||
|
||||
int kv_block_idx_ptr = 32;
|
||||
uint32_t kv_block_idx_storage;
|
||||
|
||||
while (fetched_next_task) {
|
||||
// Prefetch next Q when current Q changes
|
||||
bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1));
|
||||
q_idx = next_q_idx;
|
||||
kv_idx = next_kv_idx;
|
||||
num_kv = next_num_kv;
|
||||
|
||||
// Wait Q consumer release and issue TMA Q
|
||||
if (prefetch_q) {
|
||||
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
||||
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
|
||||
issue_tma_q(q_stage_idx, q_idx + 1);
|
||||
}
|
||||
|
||||
// Read KV block index
|
||||
// TODO: deal with `-1`?
|
||||
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
|
||||
kv_block_idx_ptr = 0;
|
||||
kv_block_idx_storage = (kv_idx + kv_group_idx + + lane_idx * kNumMathWarpGroups < num_kv ?
|
||||
__ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0);
|
||||
}
|
||||
const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
|
||||
|
||||
// Wait KV consumer release
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
||||
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
|
||||
|
||||
// Issue TMA KV
|
||||
if (cute::elect_one_sync()) {
|
||||
tma_3d_copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
|
||||
smem_kv[kv_stage_idx], 0, 0, kv_block_idx);
|
||||
tma_copy(&tensor_map_kv_scales, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
|
||||
smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
|
||||
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fetch next task
|
||||
fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv);
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4];
|
||||
const auto& sub_warp_offset = (warp_idx % 4) * 16;
|
||||
const auto& v_0_offset = lane_idx / 4 + 0;
|
||||
const auto& v_1_offset = lane_idx / 4 + 8;
|
||||
|
||||
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
|
||||
uint32_t q_idx = batch_size, kv_idx;
|
||||
uint32_t next_q_idx, next_kv_idx, next_num_kv;
|
||||
uint32_t q_stage_idx, q_phase;
|
||||
|
||||
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
|
||||
// Current Q changes
|
||||
if (q_idx != next_q_idx) {
|
||||
// Release Last Q empty
|
||||
if (q_iter_idx > 0)
|
||||
empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
|
||||
|
||||
// Wait TMA Q arrival
|
||||
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
|
||||
full_q_barriers[q_stage_idx]->wait(q_phase);
|
||||
|
||||
// Read weights
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNextN; ++ i) {
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
|
||||
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
|
||||
}
|
||||
}
|
||||
|
||||
// Get current Q and KV index
|
||||
q_idx = next_q_idx;
|
||||
kv_idx = next_kv_idx;
|
||||
|
||||
// Calculate KV offset in advance
|
||||
auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
|
||||
|
||||
// Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
|
||||
// Wait TMA KV arrival
|
||||
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
|
||||
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
|
||||
|
||||
// Issue WGMMA
|
||||
DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size");
|
||||
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
||||
auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
|
||||
WGMMA::wgmma(desc_a, desc_b, accum, k);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
|
||||
// Read per-KV scales
|
||||
float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset);
|
||||
float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset);
|
||||
|
||||
// Wait WGMMA
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Release KV empty
|
||||
empty_kv_barriers[kv_stage_idx]->arrive();
|
||||
|
||||
// Reduce over the head dim and store
|
||||
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation");
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == kNextN, "Invalid accumulation");
|
||||
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNextN; ++ i) {
|
||||
auto shifted_accum = accum + i * kNumAccumPerReduce;
|
||||
const auto& transform = [&](const uint32_t& j) {
|
||||
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
|
||||
};
|
||||
|
||||
// Intra-thread reduction
|
||||
float sum[4] = {transform(0), transform(1), transform(2), transform(3)};
|
||||
#pragma unroll
|
||||
for (uint32_t j = 1; j < kNumHeads / 8; ++ j) {
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < 4; k ++)
|
||||
sum[k] += transform(j * 4 + k);
|
||||
}
|
||||
float v_0 = (sum[0] + sum[1]) * scale_kv_0;
|
||||
float v_1 = (sum[2] + sum[3]) * scale_kv_1;
|
||||
|
||||
// Inter-thread reduction
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; ++ j) {
|
||||
const auto& offset = static_cast<int>(1u << j);
|
||||
v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
|
||||
v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
|
||||
}
|
||||
|
||||
// Store into the global memory
|
||||
// NOTES: we have redundant writes here, consider more carefully
|
||||
logits[kv_offset + i * logits_stride + v_0_offset] = v_0;
|
||||
logits[kv_offset + i * logits_stride + v_1_offset] = v_1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
67
deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh
Normal file
67
deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh
Normal file
@@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <uint32_t kNextN, uint32_t BLOCK_KV, uint32_t kNumWarps>
|
||||
__global__ __launch_bounds__(kNumWarps * 32, 1)
|
||||
void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_kv,
|
||||
const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) {
|
||||
const uint32_t& num_sms = gridDim.x;
|
||||
const uint32_t& sm_idx = blockIdx.x;
|
||||
const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
constexpr float neg_inf = -cute::numeric_limits<float>::infinity();
|
||||
|
||||
// Allocate filled `-inf` shared memory
|
||||
extern __shared__ __align__(1024) float smem_buffer[];
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32)
|
||||
smem_buffer[i] = neg_inf;
|
||||
cute::tma_store_fence();
|
||||
__syncthreads();
|
||||
|
||||
// Assign sequence to each warp
|
||||
const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx,
|
||||
const uint32_t& start, const uint32_t& total) -> cute::tuple<uint32_t, uint32_t> {
|
||||
const auto& per = total / num, rem = total % num;
|
||||
return {start + idx * per + min(idx, rem), per + (idx < rem)};
|
||||
};
|
||||
CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len);
|
||||
CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len);
|
||||
|
||||
if (cute::elect_one_sync()) {
|
||||
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
|
||||
const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
|
||||
const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
|
||||
const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;
|
||||
|
||||
for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) {
|
||||
const auto& right = min(left + BLOCK_KV, static_cast<uint32_t>(stride_kv));
|
||||
if (right <= ks or ke <= left) {
|
||||
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + left, (right - left) * sizeof(float));
|
||||
} else {
|
||||
if (left < aligned_ks)
|
||||
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + left, (aligned_ks - left) * sizeof(float));
|
||||
if (aligned_ke < right)
|
||||
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + aligned_ke, (right - aligned_ke) * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
|
||||
const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
|
||||
const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
|
||||
const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;
|
||||
for (uint32_t j = aligned_ks; j < ks; ++ j)
|
||||
logits[i * stride_kv + j] = neg_inf;
|
||||
for (uint32_t j = ke; j < aligned_ke; ++ j)
|
||||
logits[i * stride_kv + j] = neg_inf;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -55,6 +55,186 @@ def test_gemm_skip_head_mid() -> None:
|
||||
print()
|
||||
|
||||
|
||||
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
|
||||
num_blocks, block_size, num_heads, head_dim = x.shape
|
||||
assert num_heads == 1
|
||||
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), device=x.device, dtype=torch.uint8)
|
||||
x_fp8[ :, : block_size * head_dim] = x_scaled.view(num_blocks, block_size * head_dim).view(dtype=torch.uint8)
|
||||
x_fp8[ :, block_size * head_dim :] = sf.view(num_blocks, block_size).view(dtype=torch.uint8)
|
||||
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
|
||||
|
||||
|
||||
def generate_cp_test_data(seq_len, seq_len_kv):
|
||||
assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
|
||||
chunk_size = seq_len // 2
|
||||
cp_size = seq_len_kv // seq_len
|
||||
# Select an arbitrary CP rank
|
||||
cp_id = cp_size // 3
|
||||
ks = torch.zeros(seq_len, dtype=torch.int, device='cuda')
|
||||
ke = torch.zeros(seq_len, dtype=torch.int, device='cuda')
|
||||
for i in range(chunk_size):
|
||||
ke[i] = cp_id * chunk_size + i
|
||||
ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
|
||||
return ks, ke
|
||||
|
||||
|
||||
def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False):
|
||||
seq_len_kv = kv.shape[0]
|
||||
|
||||
if cost_only:
|
||||
start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv)
|
||||
end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv)
|
||||
count_ones_per_row = (end - start).clamp(min=0)
|
||||
return count_ones_per_row.sum()
|
||||
|
||||
k = kv
|
||||
q = q.float()
|
||||
k = k.float()
|
||||
|
||||
mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None]
|
||||
mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None]
|
||||
mask = mask_lo & mask_hi
|
||||
|
||||
score = torch.einsum('mhd,nd->hmn', q, k)
|
||||
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
||||
logits = logits.masked_fill(~mask, float('-inf'))
|
||||
|
||||
cost = mask.sum()
|
||||
return logits, cost
|
||||
|
||||
|
||||
def test_mqa_logits():
|
||||
print('Testing FP8 MQA Logits:')
|
||||
num_heads, head_dim = 64, 128
|
||||
for seq_len in (2048, 4096):
|
||||
for seq_len_kv in (4096, 8192, 16384, 32768, 65536, 131072):
|
||||
for disable_cp in (False, True):
|
||||
q = torch.randn(seq_len, num_heads, head_dim, device='cuda', dtype=torch.bfloat16)
|
||||
kv = torch.randn(seq_len_kv, head_dim, device='cuda', dtype=torch.bfloat16)
|
||||
weights = torch.randn(seq_len, num_heads, device='cuda', dtype=torch.float32)
|
||||
|
||||
if disable_cp:
|
||||
ks = torch.zeros(seq_len, dtype=torch.int, device='cuda')
|
||||
ke = torch.arange(seq_len, dtype=torch.int, device='cuda') + (seq_len_kv - seq_len)
|
||||
else:
|
||||
ks, ke = generate_cp_test_data(seq_len, seq_len_kv)
|
||||
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
|
||||
logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
|
||||
|
||||
do_check = (seq_len_kv < 32768)
|
||||
if do_check:
|
||||
ref_logits, ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
|
||||
|
||||
ref_neginf_mask = (ref_logits == float('-inf'))
|
||||
neginf_mask = (logits == float('-inf'))
|
||||
assert torch.equal(neginf_mask, ref_neginf_mask)
|
||||
|
||||
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
|
||||
logits = logits.masked_fill(neginf_mask, 0)
|
||||
diff = calc_diff(logits, ref_logits)
|
||||
assert diff < 1e-3, f"{diff=}"
|
||||
else:
|
||||
ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke, cost_only=True)
|
||||
|
||||
tflops = 2 * ref_cost * num_heads * head_dim / 1e12
|
||||
t, clean_t = bench_kineto(lambda: deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke),
|
||||
('fp8_mqa_logits', 'clean_logits'))
|
||||
clean_bytes = (seq_len * seq_len_kv - ref_cost) * 4 + count_bytes(ks, ke)
|
||||
print(f' > S={seq_len:4}, SKV={seq_len_kv:6}, H={num_heads:3}, D={head_dim:3}, CP={0 if disable_cp else 1}: '
|
||||
f'{tflops / t:4.0f} TFLOPS, {t * 1e6:4.0f} us, '
|
||||
f'{(count_bytes(q_fp8, kv_fp8, weights, ks, ke) + ref_cost * 4) / t / 1e9:4.0f} GB/s | '
|
||||
f'clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor,
|
||||
weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor,
|
||||
max_model_len: int):
|
||||
batch_size, next_n, heads, dim = q.size()
|
||||
num_block, block_size, _, dim = kv_cache.size()
|
||||
logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32)
|
||||
context_lens = context_lens.tolist()
|
||||
for i in range(batch_size):
|
||||
context_len = context_lens[i]
|
||||
q_offsets = torch.arange(context_len - next_n, context_len, device='cuda')
|
||||
weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous()
|
||||
for block_rk in range(ceil_div(context_len, block_size)):
|
||||
block_idx = block_tables[i][block_rk]
|
||||
qx, kx = q[i], kv_cache[block_idx]
|
||||
k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device='cuda')
|
||||
mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None])
|
||||
s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), float('-inf'))
|
||||
s = torch.relu(s) * weight_slice[..., None]
|
||||
s = s.sum(dim=0)
|
||||
logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf'))
|
||||
return logits
|
||||
|
||||
|
||||
def test_paged_mqa_logits():
|
||||
print('Testing FP8 Paged MQA Logits:')
|
||||
max_model_len = 111 * 1000
|
||||
for batch_size, next_n in [(64, 1), (64, 2), (128, 1)]:
|
||||
for heads, index_dim in [(64, 128)]:
|
||||
for avg_kv in (8192, 32768):
|
||||
num_blocks, blocksize = max_model_len * 3, 64
|
||||
|
||||
q = torch.randn((batch_size, next_n, heads, index_dim), device='cuda', dtype=torch.bfloat16)
|
||||
kv_cache = torch.randn((num_blocks, blocksize, 1, index_dim), device='cuda', dtype=torch.bfloat16)
|
||||
weights = torch.randn((batch_size * next_n, heads), device='cuda', dtype=torch.float32)
|
||||
|
||||
context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size, )).cuda().to(torch.int32)
|
||||
max_block_len = (context_lens.max().item() + blocksize - 1) // blocksize * blocksize
|
||||
block_tables = torch.zeros((batch_size, max_block_len), device='cuda', dtype=torch.int32)
|
||||
|
||||
counter = 0
|
||||
block_idx_pool = list(range(num_blocks))
|
||||
random.shuffle(block_idx_pool)
|
||||
for i in range(batch_size):
|
||||
ctx_len = context_lens[i].item()
|
||||
for j in range(ceil_div(ctx_len, blocksize)):
|
||||
block_tables[i][j] = block_idx_pool[counter]
|
||||
counter += 1
|
||||
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
|
||||
|
||||
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, 132)
|
||||
logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True)
|
||||
|
||||
ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len)
|
||||
positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1)
|
||||
row_indices = torch.arange(batch_size * next_n, device='cuda') // next_n
|
||||
next_n_offset = torch.arange(batch_size * next_n, device='cuda') % next_n
|
||||
ref_neginf_mask = ~(positions <= (context_lens[row_indices] - next_n + next_n_offset).unsqueeze(1))
|
||||
|
||||
neginf_mask = (logits == float('-inf'))
|
||||
assert torch.equal(neginf_mask, ref_neginf_mask)
|
||||
|
||||
logits = logits.masked_fill(neginf_mask, 0)
|
||||
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
|
||||
diff = calc_diff(logits, ref_logits)
|
||||
assert diff < 1e-3, f"{diff=}"
|
||||
|
||||
sum_lens = sum(context_lens.to(torch.int64))
|
||||
tflops = 2 * sum_lens * next_n * heads * index_dim / 1e12
|
||||
input_bytes = count_bytes(q_fp8, weights, context_lens) + sum_lens * (index_dim + 4) + (sum_lens / blocksize) * 4
|
||||
output_bytes = sum_lens * next_n * 4
|
||||
t, clean_t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True),
|
||||
('fp8_paged_mqa_logits', 'clean_logits'))
|
||||
clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens)
|
||||
print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: '
|
||||
f'{tflops / t:3.0f} TFLOPS, {t * 1e6:3.0f} us, '
|
||||
f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s | '
|
||||
f'clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
@@ -62,3 +242,7 @@ if __name__ == '__main__':
|
||||
random.seed(0)
|
||||
|
||||
test_gemm_skip_head_mid()
|
||||
|
||||
if get_arch_major() == 9:
|
||||
test_mqa_logits()
|
||||
test_paged_mqa_logits()
|
||||
|
||||
Reference in New Issue
Block a user