Files
DeepGEMM/csrc/apis/attention.hpp
Zhean Xu 891d57b4db Add various optimizations and Mega MoE benchmarks (#316)
* Merge with private repo

* Add Mega MoE Benchmark

* Minor fix

* Update

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
2026-04-24 18:41:37 +08:00

454 lines
22 KiB
C++

#pragma once
#include "../utils/compatibility.hpp"
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
#include "../jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp"
#include "../jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp"
#include "../jit_kernels/impls/smxx_clean_logits.hpp"
#endif
#include "layout.hpp"
namespace deep_gemm::attention {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// FP8 GEMM `D = A @ B.T` launched with an `EpilogueHeadSplits<left, mid, right>` epilogue.
//
// Arguments:
//  - `a`: `(values, scale factors)`; values are `[M, K]` in `float8_e4m3fn`
//  - `b`: `(values, scale factors)`; values are `[N, K]` in `float8_e4m3fn`
//  - `d`: output in BF16 or FP32, shaped `[M, N + N / (left + right) * mid]`
//  - `head_splits`: `(left, mid, right)`; `left + right` must be positive and divide `N`
//  - `recipe`: scale-factor recipe handed to the layout transform
//  - `compiled_dims`: forwarded unchanged to the kernel launcher
//  - `disable_ue8m0_cast`: forwarded unchanged to the layout transform
//
// Returns without launching anything when `M == 0`. Asserts on any shape, dtype or
// layout mismatch, and on architecture/scale-factor combinations other than
// SM90 with FP32 scale factors or SM100 with `int` scale factors.
// NOTE(review): judging by the name, the epilogue presumably leaves the `mid` columns
// of every head in `d` untouched — confirm against `EpilogueHeadSplits`.
static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const std::tuple<int, int, int>& head_splits,
                                      std::optional<std::tuple<int, int, int>> recipe,
                                      const std::string& compiled_dims,
                                      const bool& disable_ue8m0_cast) {
    // Shape must be `[M, K] @ [N, K].T`
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }

    // D must be N-major
    check_major_type_cd(d);

    // Type and shape checks
    const auto [m , k ] = get_shape<2>(a.first);
    const auto [n , k_] = get_shape<2>(b.first);
    const auto [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Check head splits and N
    // NOTES: `left + right` is used as a divisor below, so it must be validated first;
    // a zero sum would otherwise be an integer division by zero
    const auto [left, mid, right] = head_splits;
    DG_HOST_ASSERT(left + right > 0);
    DG_HOST_ASSERT(n % (left + right) == 0 and n_ == n + n / (left + right) * mid);

    // Do nothing if the problem is empty
    if (m == 0)
        return;

    // Transform SFA and SFB into compute-required layout
    const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
        a.second, b.second, m, n, k, recipe, std::nullopt, std::nullopt,
        std::nullopt, std::nullopt, disable_ue8m0_cast);
    DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128);

    // Dispatch into different implementations
    // NOTE(review): `recipe.value()` below assumes `recipe` holds a value by now —
    // presumably the layout transform fills in a default when the caller passes none; confirm
    const auto arch_major = device_runtime->get_arch_major();
    const auto epilogue_type = fmt::format("epilogue::transform::EpilogueHeadSplits<{}, {}, {}>", left, mid, right);
    if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) {
        const auto major_sfb = get_major_type_ab(sfb);
        sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, major_sfb, compiled_dims, epilogue_type);
    } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
        // NOTES: Only granularity 128 and FP8 are exposed in the API
        sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k,
                                128, 128, major_a, major_b, compiled_dims, epilogue_type);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
    }
}
// Validates the inputs, allocates the output and launches the MQA logits kernel for
// either FP8 or packed-FP4 operands.
//
// Arguments:
//  - `q`: `(values, optional scale factors)`; the FP4 path is taken iff scale factors are given
//      - FP4: contiguous packed values `[seq_len, num_heads, head_dim / 2]` with `head_dim == 128`,
//        contiguous `int32` scale factors `[seq_len, num_heads]`
//      - FP8: contiguous `float8_e4m3fn` values `[seq_len, num_heads, head_dim]`,
//        `head_dim` in {32, 64, 128}
//  - `kv`: `(values, scale factors)`; contiguous values `[seq_len_kv, head_dim]` (packed for FP4)
//    and contiguous 1D scale factors `[seq_len_kv]` (`int32` for FP4, FP32 for FP8)
//  - `weights`: FP32 `[seq_len, num_heads]` with unit stride on the last dimension
//  - `cu_seq_len_k_start`, `cu_seq_len_k_end`: contiguous `int` tensors with `seq_len` entries
//  - `clean_logits`: run `smxx_clean_logits` on the result; only accepted when `max_seqlen_k == 0`
//  - `max_seqlen_k`: `0` yields `seq_len_kv` output columns, otherwise this many columns
//  - `logits_dtype`: element type of the returned tensor
//
// Returns a `[seq_len, columns]` view into a larger buffer whose row count is aligned to
// `128 / num_heads` and whose row stride is `stride_logits`.
static torch::Tensor fp8_fp4_mqa_logits(const std::tuple<torch::Tensor, std::optional<torch::Tensor>>& q,
                                        const std::tuple<torch::Tensor, torch::Tensor>& kv,
                                        const torch::Tensor& weights,
                                        const torch::Tensor& cu_seq_len_k_start,
                                        const torch::Tensor& cu_seq_len_k_end,
                                        const bool& clean_logits,
                                        const int& max_seqlen_k,
                                        const at::ScalarType& logits_dtype) {
    const auto [q_fp, q_sf] = q;
    const auto [kv_fp, kv_sf] = kv;

    // The presence of Q scale factors is what selects the FP4 path
    const bool is_fp4 = q_sf.has_value();

    int seq_len, seq_len_kv, num_heads, head_dim;
    if (not is_fp4) {
        // Check FP8 Q
        std::tie(seq_len, num_heads, head_dim) = get_shape<3>(q_fp);
        DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
        DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
        DG_HOST_ASSERT(q_fp.is_contiguous());
        DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn);

        // Check FP8 KV
        const auto [kv_rows, kv_head_dim] = get_shape<2>(kv_fp);
        seq_len_kv = kv_rows;
        DG_HOST_ASSERT(head_dim == kv_head_dim);
        DG_HOST_ASSERT(kv_fp.is_contiguous());
        DG_HOST_ASSERT(kv_fp.scalar_type() == torch::kFloat8_e4m3fn);

        // Check SF KV
        const auto [kv_sf_rows] = get_shape<1>(kv_sf);
        DG_HOST_ASSERT(seq_len_kv == kv_sf_rows);
        DG_HOST_ASSERT(kv_sf.is_contiguous());
        DG_HOST_ASSERT(kv_sf.scalar_type() == torch::kFloat);
    } else {
        // Check FP4 Q
        // NOTES: the innermost extent counts packed elements, so the logical head dim is twice that
        int packed_head_dim;
        std::tie(seq_len, num_heads, packed_head_dim) = get_shape<3>(q_fp);
        head_dim = packed_head_dim * 2;
        DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
        DG_HOST_ASSERT(head_dim == 128);
        DG_HOST_ASSERT(q_fp.is_contiguous());
        DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4);

        // Check SF Q
        const auto [q_sf_rows, q_sf_heads] = get_shape<2>(q_sf.value());
        DG_HOST_ASSERT(seq_len == q_sf_rows and num_heads == q_sf_heads);
        DG_HOST_ASSERT(q_sf.value().is_contiguous());
        DG_HOST_ASSERT(q_sf.value().scalar_type() == torch::kInt32);

        // Check FP4 KV
        const auto [kv_rows, kv_packed_head_dim] = get_shape<2>(kv_fp);
        seq_len_kv = kv_rows;
        DG_HOST_ASSERT(head_dim == kv_packed_head_dim * 2);
        DG_HOST_ASSERT(kv_fp.is_contiguous());
        DG_HOST_ASSERT(kv_fp.scalar_type() == kPackedFP4);

        // Check SF KV
        const auto [kv_sf_rows] = get_shape<1>(kv_sf);
        DG_HOST_ASSERT(seq_len_kv == kv_sf_rows);
        DG_HOST_ASSERT(kv_sf.is_contiguous());
        DG_HOST_ASSERT(kv_sf.scalar_type() == torch::kInt32);
    }

    // Check weights
    const auto [weights_rows, weights_heads] = get_shape<2>(weights);
    DG_HOST_ASSERT(seq_len == weights_rows and num_heads == weights_heads);
    DG_HOST_ASSERT(weights.stride(1) == 1);
    DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat);

    // Check cu_seq_len_k_start
    DG_HOST_ASSERT(cu_seq_len_k_start.size(0) == seq_len);
    DG_HOST_ASSERT(cu_seq_len_k_start.is_contiguous());
    DG_HOST_ASSERT(cu_seq_len_k_start.scalar_type() == torch::kInt);

    // Check cu_seq_len_k_end
    DG_HOST_ASSERT(cu_seq_len_k_end.size(0) == seq_len);
    DG_HOST_ASSERT(cu_seq_len_k_end.is_contiguous());
    DG_HOST_ASSERT(cu_seq_len_k_end.scalar_type() == torch::kInt);

    // Blocking: 128 (query, head) rows per block, 256 KV positions per block
    constexpr int block_qh = 128;
    constexpr int block_kv = 256;
    const int block_q = block_qh / num_heads;
    DG_HOST_ASSERT(block_qh % num_heads == 0);

    // Allocate a padded buffer and hand out only the valid window of it
    // NOTES: without an explicit `max_seqlen_k` the logits stride must be 16-byte aligned
    const bool has_max_seqlen_k = max_seqlen_k != 0;
    const int aligned_seq_len = align(seq_len, block_q);
    const int stride_logits = has_max_seqlen_k ? align(max_seqlen_k, block_kv)
                                               : align(seq_len_kv + block_kv, 8);
    const int num_columns = has_max_seqlen_k ? max_seqlen_k : seq_len_kv;
    torch::Tensor logits = torch::empty({aligned_seq_len, stride_logits}, q_fp.options().dtype(logits_dtype))
        .index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, num_columns)});
    if (has_max_seqlen_k) {
        DG_HOST_ASSERT(not clean_logits);
    }

    // Dispatch implementation
    const auto arch_major = device_runtime->get_arch_major();
    const bool is_sm90 = arch_major == 9;
    const bool is_sm100 = arch_major == 10;
    if (is_fp4 and is_sm100) {
        sm100_fp4_mqa_logits(q_fp, q_sf.value(), kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
                             seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
    } else if (not is_fp4 and (is_sm90 or is_sm100)) {
        smxx_fp8_mqa_logits(q_fp, kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
                            seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }

    // Clean unfilled logits
    if (clean_logits) {
        smxx_clean_logits(logits, cu_seq_len_k_start, cu_seq_len_k_end, 1, seq_len, seq_len_kv, stride_logits);
    }
    return logits;
}
// Builds the `[num_sms + 1, 2]` schedule metadata tensor for the paged MQA logits kernels.
//
// Arguments:
//  - `context_lens`: contiguous 2D `int` tensor `[batch_size, next_n]`
//  - `block_kv`: KV block size; 64 everywhere, 32 additionally accepted on arch 10
//  - `num_sms`: number of SMs the schedule is built for
//  - `indices`: optional contiguous 1D `int` tensor with `batch_size` entries; when given,
//    arch 10 and `next_n == 1` are required
//
// The result is allocated with the options of `context_lens` and filled by
// `smxx_paged_mqa_logits_metadata`.
static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms, const std::optional<torch::Tensor>& indices) {
    // NOTES: Only 2D context lens is supported for now
    DG_HOST_ASSERT(context_lens.dim() == 2);
    const bool is_context_lens_2d = true;

    const int batch_size = context_lens.size(0);
    const int next_n = context_lens.size(1);
    DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
    DG_HOST_ASSERT(context_lens.is_contiguous());

    // Create metadata tensor
    auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options());

    // Dispatch implementation
    const auto arch_major = device_runtime->get_arch_major();
    const bool is_sm90 = arch_major == 9;
    const bool is_sm100 = arch_major == 10;
    if (indices.has_value()) {
        // Variable-length path, driven by the per-batch indices
        const auto& indices_tensor = *indices;
        DG_HOST_ASSERT(is_sm100 and next_n == 1 and (block_kv == 64 or block_kv == 32));
        DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
        DG_HOST_ASSERT(indices_tensor.is_contiguous());
        DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
        smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms,
                                       is_context_lens_2d, true, indices_tensor.data_ptr<int>());
    } else if (not (is_sm90 or is_sm100)) {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    } else {
        DG_HOST_ASSERT(block_kv == 64 or (is_sm100 and block_kv == 32));
        smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms,
                                       is_context_lens_2d, false, nullptr);
    }
    return schedule_metadata;
}
// Validates the inputs, allocates the output and launches the paged MQA logits kernel
// for either FP8 or packed-FP4 operands.
//
// Arguments:
//  - `q`: `(values, optional scale factors)`; the FP4 path is taken iff scale factors are given
//      - FP4: contiguous packed values `[batch_size, next_n, num_heads, head_dim / 2]` with
//        `head_dim == 128`, contiguous `int32` scale factors `[batch_size, next_n, num_heads]`
//      - FP8: contiguous `float8_e4m3fn` values `[batch_size, next_n, num_heads, head_dim]`,
//        `head_dim` in {32, 64, 128}
//  - `fused_kv_cache`: byte tensor `[num_kv_blocks, block_kv, 1, value bytes + 4]` with
//    `block_kv` in {32, 64}; inside every block all value bytes come first and the
//    4-byte scale factors follow (see the pointer offsets used for the two views below)
//  - `weights`: FP32 `[batch_size * next_n, num_heads]`, unit stride on the last dimension
//    (fully contiguous on the FP8 path)
//  - `context_lens`: contiguous 2D `int` tensor `[batch_size, next_n]`
//  - `block_table`: `int` tensor `[batch_size, max blocks]`, unit stride on the last dimension
//  - `schedule_meta`: contiguous `int` tensor `[num_sms + 1, 2]`
//  - `max_context_len`: number of columns in the returned logits
//  - `clean_logits`: currently always rejected, since cleaning requires non-2D context lens
//    and only 2D context lens is supported
//  - `logits_dtype`: FP32 or BF16
//  - `indices`: optional contiguous 1D `int` tensor with `batch_size` entries; when given,
//    arch 10 and `next_n == 1` are required
//
// Returns a `[batch_size * next_n, max_context_len]` view into a buffer whose row
// stride is `max_context_len` aligned up to 256.
static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, std::optional<torch::Tensor>>& q,
                                              const torch::Tensor& fused_kv_cache,
                                              const torch::Tensor& weights,
                                              const torch::Tensor& context_lens,
                                              const torch::Tensor& block_table,
                                              const torch::Tensor& schedule_meta,
                                              const int& max_context_len,
                                              const bool& clean_logits,
                                              const at::ScalarType& logits_dtype,
                                              const std::optional<torch::Tensor>& indices) {
    const auto [q_fp, q_sf] = q;
    const bool is_fp4 = q_sf.has_value();

    torch::Tensor kv_cache, kv_cache_sf;
    int batch_size, next_n, num_heads, head_dim;
    int num_kv_blocks, block_kv;
    int kv_cache_stride_bytes;
    const int num_sms = device_runtime->get_num_sms();

    // NOTE(review): both `from_blob` views below alias `fused_kv_cache` storage without
    // owning it and are built from default `TensorOptions` (dtype only) — presumably the
    // launchers consume nothing but their pointers, shapes and strides; confirm
    if (is_fp4) {
        // Check FP4 Q
        // NOTES: the innermost extent counts packed elements, so the logical head dim is twice that
        std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp);
        head_dim *= 2;
        DG_HOST_ASSERT(next_n >= 1);
        DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
        DG_HOST_ASSERT(head_dim == 128);
        DG_HOST_ASSERT(q_fp.is_contiguous());
        DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4);

        // Check SF Q
        auto [_batch_size, _next_n, _num_heads] = get_shape<3>(q_sf.value());
        DG_HOST_ASSERT(batch_size == _batch_size and next_n == _next_n and num_heads == _num_heads);
        DG_HOST_ASSERT(q_sf.value().is_contiguous());
        DG_HOST_ASSERT(q_sf.value().scalar_type() == torch::kInt32);

        // Check fused KV cache
        int num_heads_kv, fp4_with_sf_bytes;
        std::tie(num_kv_blocks, block_kv, num_heads_kv, fp4_with_sf_bytes) = get_shape<4>(fused_kv_cache);
        DG_HOST_ASSERT(block_kv == 32 or block_kv == 64);
        DG_HOST_ASSERT(num_heads_kv == 1 and fp4_with_sf_bytes == head_dim / 2 + static_cast<int>(sizeof(int)));
        DG_HOST_ASSERT(fused_kv_cache.stride(1) == fp4_with_sf_bytes and fused_kv_cache.stride(3) == 1);
        DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte);

        // Derive FP4 values and SF tensor
        // NOTES: the SF view is indexed in 4-byte units, hence the stride divisibility check
        kv_cache_stride_bytes = fused_kv_cache.stride(0);
        DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(int) == 0);
        kv_cache = torch::from_blob(
            fused_kv_cache.data_ptr(),
            {num_kv_blocks, block_kv, head_dim / 2},
            {kv_cache_stride_bytes, head_dim / 2, 1},
            torch::TensorOptions().dtype(kPackedFP4)
        );
        kv_cache_sf = torch::from_blob(
            fused_kv_cache.data_ptr<uint8_t>() + block_kv * head_dim / 2,
            {num_kv_blocks, block_kv},
            {kv_cache_stride_bytes / static_cast<int>(sizeof(int)), 1},
            torch::TensorOptions().dtype(torch::kInt32)
        );
    } else {
        // Check FP8 Q
        std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp);
        DG_HOST_ASSERT(next_n >= 1);
        DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
        DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
        DG_HOST_ASSERT(q_fp.is_contiguous());
        DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn);

        // Check fused KV cache
        int num_heads_kv, head_dim_with_sf;
        std::tie(num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf) = get_shape<4>(fused_kv_cache);
        DG_HOST_ASSERT(block_kv == 32 or block_kv == 64);
        DG_HOST_ASSERT(num_heads_kv == 1 and head_dim_with_sf == head_dim + static_cast<int>(sizeof(float)));
        DG_HOST_ASSERT(fused_kv_cache.stride(1) == head_dim_with_sf and fused_kv_cache.stride(3) == 1);
        DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte);

        // Derive FP8 values and SF tensor
        // NOTES: the SF view is indexed in 4-byte units, hence the stride divisibility check
        kv_cache_stride_bytes = fused_kv_cache.stride(0);
        DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(float) == 0);
        kv_cache = torch::from_blob(
            fused_kv_cache.data_ptr(),
            {num_kv_blocks, block_kv, head_dim},
            {kv_cache_stride_bytes, head_dim, 1},
            torch::TensorOptions().dtype(torch::kFloat8_e4m3fn)
        );
        kv_cache_sf = torch::from_blob(
            fused_kv_cache.data_ptr<uint8_t>() + block_kv * head_dim,
            {num_kv_blocks, block_kv},
            {kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 1},
            torch::TensorOptions().dtype(torch::kFloat32)
        );

        // Weights must be contiguous for FP8
        DG_HOST_ASSERT(weights.is_contiguous());
    }

    // Check weights
    auto [_batch_size_next_n, _num_heads] = get_shape<2>(weights);
    DG_HOST_ASSERT(_batch_size_next_n == batch_size * next_n and _num_heads == num_heads);
    DG_HOST_ASSERT(weights.stride(1) == 1);
    DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat);

    // Check block table
    // NOTES: the row stride is read only after the shape has been checked
    auto [_batch_size, _max_block_len] = get_shape<2>(block_table);
    DG_HOST_ASSERT(_batch_size == batch_size);
    DG_HOST_ASSERT(block_table.stride(1) == 1);
    DG_HOST_ASSERT(block_table.scalar_type() == torch::kInt);
    const int block_table_stride = block_table.stride(0);

    // Check indices
    const bool is_varlen = indices.has_value();
    const auto arch_major = device_runtime->get_arch_major();
    const auto indices_tensor = indices.value_or(torch::Tensor());
    if (is_varlen) {
        DG_HOST_ASSERT(arch_major == 10 and next_n == 1);
        DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
        DG_HOST_ASSERT(indices_tensor.is_contiguous());
        DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
    }

    // Check schedule metadata
    auto [_schedule_meta_size, _meta_info_size] = get_shape<2>(schedule_meta);
    DG_HOST_ASSERT(_schedule_meta_size == num_sms + 1 and _meta_info_size == 2);
    DG_HOST_ASSERT(schedule_meta.is_contiguous());
    DG_HOST_ASSERT(schedule_meta.scalar_type() == torch::kInt);

    // Check context lengths
    // NOTES: Only 2D context lens is supported for now
    DG_HOST_ASSERT(context_lens.dim() == 2);
    const bool is_context_lens_2d = true;
    const auto [__batch_size, _next_n] = get_shape<2>(context_lens);
    DG_HOST_ASSERT(batch_size == __batch_size and next_n == _next_n);
    DG_HOST_ASSERT(context_lens.is_contiguous());
    DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);

    // Check output options
    // NOTES: rejected up front, so that an unsupported request fails before anything is
    // allocated or launched
    DG_HOST_ASSERT(logits_dtype == torch::kFloat32 or logits_dtype == torch::kBFloat16);
    if (clean_logits) {
        DG_HOST_ASSERT(not is_context_lens_2d);
    }

    // Allocate output
    constexpr int split_kv = 256;
    const auto aligned_max_context_len = align(max_context_len, split_kv);
    auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q_fp.options().dtype(logits_dtype));
    logits = logits.slice(-1, 0, max_context_len);

    // Dispatch implementation
    if (is_fp4 and arch_major == 10) {
        sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta,
                                   logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
                                   is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv);
    } else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) {
        smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta,
                                  logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
                                  is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }

    // Clean unfilled logits
    if (clean_logits)
        smxx_clean_logits(logits, std::nullopt, context_lens, next_n, batch_size * next_n, max_context_len, aligned_max_context_len);
    return logits;
}
// Legacy API wrappers
// Legacy FP8-only entry point: calls `fp8_fp4_mqa_logits` with no Q scale factors
// (which selects its FP8 path) and FP32 logits. All other arguments pass through unchanged.
static torch::Tensor fp8_mqa_logits(const torch::Tensor& q,
                                    const std::tuple<torch::Tensor, torch::Tensor>& kv,
                                    const torch::Tensor& weights,
                                    const torch::Tensor& cu_seq_len_k_start,
                                    const torch::Tensor& cu_seq_len_k_end,
                                    const bool& clean_logits,
                                    const int& max_seqlen_k) {
    const std::tuple<torch::Tensor, std::optional<torch::Tensor>> q_without_sf(q, std::nullopt);
    return fp8_fp4_mqa_logits(q_without_sf, kv, weights,
                              cu_seq_len_k_start, cu_seq_len_k_end,
                              clean_logits, max_seqlen_k, torch::kFloat);
}
// Legacy FP8-only entry point: calls `fp8_fp4_paged_mqa_logits` with no Q scale factors
// (which selects its FP8 path) and FP32 logits. All other arguments pass through unchanged.
static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
                                          const torch::Tensor& fused_kv_cache,
                                          const torch::Tensor& weights,
                                          const torch::Tensor& context_lens,
                                          const torch::Tensor& block_table,
                                          const torch::Tensor& schedule_meta,
                                          const int& max_context_len,
                                          const bool& clean_logits,
                                          const std::optional<torch::Tensor>& indices) {
    const std::tuple<torch::Tensor, std::optional<torch::Tensor>> q_without_sf(q, std::nullopt);
    return fp8_fp4_paged_mqa_logits(q_without_sf, fused_kv_cache, weights,
                                    context_lens, block_table, schedule_meta,
                                    max_context_len, clean_logits, torch::kFloat, indices);
}
#endif
// Registers the attention entry points of this namespace on the pybind11 module `m`.
// Every binding sits behind `DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE`; when that
// condition is false this function registers nothing.
static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// FP8 GEMM with the head-split epilogue; `recipe` may be omitted, `compiled_dims` defaults to "nk"
m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("head_splits"),
py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
// Non-paged MQA logits; cleans the logits by default and emits FP32 unless told otherwise
m.def("fp8_fp4_mqa_logits", &fp8_fp4_mqa_logits,
py::arg("q"), py::arg("kv"), py::arg("weights"),
py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"),
py::arg("clean_logits") = true,
py::arg("max_seqlen_k") = 0,
py::arg("logits_dtype") = torch::kFloat32);
// Schedule metadata for the paged kernels; `indices` is optional
m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata,
py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"),
py::arg("indices") = std::nullopt);
// Paged MQA logits; the C++ parameter `fused_kv_cache` is exposed under the keyword `kv_cache`,
// and unlike the non-paged variant `clean_logits` defaults to false
m.def("fp8_fp4_paged_mqa_logits", &fp8_fp4_paged_mqa_logits,
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),
py::arg("max_context_len"),
py::arg("clean_logits") = false,
py::arg("logits_dtype") = torch::kFloat32,
py::arg("indices") = std::nullopt);
// Legacy API
// NOTES: these wrappers take no `logits_dtype` argument; they pass `torch::kFloat` themselves
m.def("fp8_mqa_logits", &fp8_mqa_logits,
py::arg("q"), py::arg("kv"), py::arg("weights"),
py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"),
py::arg("clean_logits") = true,
py::arg("max_seqlen_k") = 0);
m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits,
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),
py::arg("max_context_len"), py::arg("clean_logits") = false,
py::arg("indices") = std::nullopt);
#endif
}
} // namespace deep_gemm::attention