diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp index 286387b..2962eaf 100644 --- a/csrc/apis/attention.hpp +++ b/csrc/apis/attention.hpp @@ -4,6 +4,9 @@ #include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" #include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" #include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp" +#include "../jit_kernels/impls/smxx_fp8_mqa_logits.hpp" +#include "../jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp" +#include "../jit_kernels/impls/smxx_clean_logits.hpp" #include "layout.hpp" @@ -66,12 +69,172 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair& kv, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const bool& clean_logits) { + const auto& [seq_len, num_heads, head_dim] = get_shape<3>(q); + const auto& [seq_len_kv, head_dim_] = get_shape<2>(kv.first); + const auto& [seq_len_, num_heads_] = get_shape<2>(weights); + const auto& [seq_len_kv_] = get_shape<1>(kv.second); + + DG_HOST_ASSERT(seq_len == seq_len_); + DG_HOST_ASSERT(num_heads == num_heads_ and head_dim == head_dim_); + DG_HOST_ASSERT(seq_len_kv == seq_len_kv_); + DG_HOST_ASSERT(cu_seq_len_k_start.size(0) == seq_len); + DG_HOST_ASSERT(cu_seq_len_k_end.size(0) == seq_len); + + DG_HOST_ASSERT(q.is_contiguous() and kv.first.is_contiguous()); + DG_HOST_ASSERT(kv.second.is_contiguous()); + DG_HOST_ASSERT(weights.is_contiguous()); + DG_HOST_ASSERT(cu_seq_len_k_start.is_contiguous()); + DG_HOST_ASSERT(cu_seq_len_k_end.is_contiguous()); + + DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(kv.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(kv.second.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(cu_seq_len_k_start.scalar_type() == torch::kInt); + DG_HOST_ASSERT(cu_seq_len_k_end.scalar_type() == torch::kInt); + + constexpr int seq_len_alignment = 4; + constexpr int block_kv = 256; + const auto aligned_seq_len = align(seq_len, seq_len_alignment); + const auto aligned_seq_len_kv = align(seq_len_kv + block_kv, 4); + auto logits = torch::empty({aligned_seq_len, aligned_seq_len_kv}, q.options().dtype(torch::kFloat)); + logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, seq_len_kv)}); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, + seq_len, seq_len_kv, aligned_seq_len_kv, num_heads, head_dim, seq_len_alignment); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + // Clean unfilled logits + if (clean_logits) + smxx_clean_logits(logits, cu_seq_len_k_start, cu_seq_len_k_end, 1, seq_len, seq_len_kv, aligned_seq_len_kv); + return logits; +} + +static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms) { + const auto& [batch_size] = get_shape<1>(context_lens); + DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); + DG_HOST_ASSERT(context_lens.is_contiguous()); + + auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options()); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, block_kv, num_sms); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + return schedule_metadata; +} + +static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, + const torch::Tensor& fused_kv_cache, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& block_table, + const torch::Tensor& schedule_meta, + const int& max_context_len, + const bool& clean_logits) { + const auto& [batch_size, next_n, num_heads, head_dim] = get_shape<4>(q); + const auto& [num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf] = get_shape<4>(fused_kv_cache); + const auto& [batch_size_] = get_shape<1>(context_lens); + const auto& [batch_size_next_n, num_heads_] = get_shape<2>(weights); + const auto& [batch_size__, max_block_len] = get_shape<2>(block_table); + const auto& [schedule_meta_size, meta_info_size] = get_shape<2>(schedule_meta); + const auto& num_sms = device_runtime->get_num_sms(); + const auto& kv_cache_stride_bytes = fused_kv_cache.stride(0); + const auto& block_table_stride = block_table.stride(0); + + DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size__); + DG_HOST_ASSERT(batch_size_next_n == batch_size * next_n); + DG_HOST_ASSERT(num_heads == num_heads_ and num_heads_kv == 1); + DG_HOST_ASSERT(head_dim_with_sf == head_dim + static_cast(sizeof(float))); + DG_HOST_ASSERT(schedule_meta_size == num_sms + 1 and meta_info_size == 2); + + DG_HOST_ASSERT(next_n == 1 or next_n == 2); + DG_HOST_ASSERT(block_kv == 64); + + DG_HOST_ASSERT(q.is_contiguous()); + DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(float) == 0); + DG_HOST_ASSERT(fused_kv_cache.stride(1) == head_dim_with_sf); + DG_HOST_ASSERT(fused_kv_cache.stride(2) == head_dim_with_sf); + DG_HOST_ASSERT(fused_kv_cache.stride(3) == 1); + DG_HOST_ASSERT(weights.is_contiguous()); + DG_HOST_ASSERT(context_lens.is_contiguous()); + DG_HOST_ASSERT(block_table.stride(1) == 1); + DG_HOST_ASSERT(schedule_meta.is_contiguous()); + + DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte); + DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); + DG_HOST_ASSERT(block_table.scalar_type() == torch::kInt); + DG_HOST_ASSERT(schedule_meta.scalar_type() == torch::kInt); + + // Derive FP8 values and SF tensor from KV cache + const auto& kv_cache = torch::from_blob( + fused_kv_cache.data_ptr(), + {num_kv_blocks, block_kv, head_dim}, + {kv_cache_stride_bytes, head_dim, 1}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn) + ); + const auto& kv_cache_scales = torch::from_blob( + fused_kv_cache.data_ptr() + block_kv * head_dim, + {num_kv_blocks, block_kv}, + {kv_cache_stride_bytes / static_cast(sizeof(float)), 1}, + torch::TensorOptions().dtype(torch::kFloat32) + ); + + // Allocate output + constexpr int num_math_warp_groups = 4; + const auto& aligned_max_context_len = align(max_context_len, num_math_warp_groups * block_kv); + auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q.options().dtype(torch::kFloat)); + logits = logits.slice(-1, 0, max_context_len); + + // Dispatch implementation + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta, + batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, + kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, num_math_warp_groups); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + // Clean unfilled logits + if (clean_logits) + smxx_clean_logits(logits, std::nullopt, context_lens, next_n, batch_size * next_n, max_context_len, aligned_max_context_len); + return logits; +} + static void register_apis(pybind11::module_& m) { m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("head_splits"), py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + m.def("fp8_mqa_logits", &fp8_mqa_logits, + py::arg("q"), py::arg("kv"), py::arg("weights"), + py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"), + py::arg("clean_logits") = true); + m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata, + py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms")); + m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits, + py::arg("q"), py::arg("kv_cache"), py::arg("weights"), + py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"), + py::arg("max_context_len"), py::arg("clean_logits") = false); } } // namespace deep_gemm::attention diff --git a/csrc/indexing/main.cu b/csrc/indexing/main.cu index 42c03a8..ed5a214 100644 --- a/csrc/indexing/main.cu +++ b/csrc/indexing/main.cu @@ -6,6 +6,11 @@ #include #include +// Attention kernels +#include +#include +#include + // Einsum kernels #include #include diff --git a/csrc/jit_kernels/impls/smxx_clean_logits.hpp b/csrc/jit_kernels/impls/smxx_clean_logits.hpp new file mode 100644 index 0000000..cdb472d --- /dev/null +++ b/csrc/jit_kernels/impls/smxx_clean_logits.hpp @@ -0,0 +1,79 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +class SMXXCleanLogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int next_n; + int seq_len; + int seq_len_kv; + uint64_t stride_kv; + + int* cu_seq_len_k_start; + int* cu_seq_len_k_end; + float* logits; + + int block_kv; + int num_warps; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&smxx_clean_logits< + {}, {}, {} + >); +}}; +)", args.next_n, args.block_kv, args.num_warps); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.seq_len, args.seq_len_kv, static_cast(args.stride_kv), + args.cu_seq_len_k_start, args.cu_seq_len_k_end, args.logits + )); + } +}; + +static void smxx_clean_logits(const torch::Tensor& logits, + const std::optional& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const int& next_n, + const int& seq_len, const int& seq_len_kv, + const uint64_t &stride_kv) { + const int block_kv = 8192; + const int num_warps = 8; + const int smem_size = block_kv * sizeof(float); + + // Launch + const SMXXCleanLogitsRuntime::Args& args = { + .next_n = next_n, + .seq_len = seq_len, + .seq_len_kv = seq_len_kv, + .stride_kv = stride_kv, + .cu_seq_len_k_start = cu_seq_len_k_start.has_value() ? cu_seq_len_k_start.value().data_ptr() : nullptr, + .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), + .logits = logits.data_ptr(), + .block_kv = block_kv, + .num_warps = num_warps, + .launch_args = LaunchArgs(device_runtime->get_num_sms(), + num_warps * 32, smem_size) + }; + const auto& code = SMXXCleanLogitsRuntime::generate(args); + const auto& runtime = compiler->build("smxx_clean_logits", code); + SMXXCleanLogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp new file mode 100644 index 0000000..f585639 --- /dev/null +++ b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp @@ -0,0 +1,152 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../heuristics/sm90.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90FP8MQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int seq_len; + int seq_len_kv; + int stride_kv; + int num_heads, head_dim; + int num_q_stages; + int num_kv_stages; + + int block_q; + int block_kv; + + int* cu_seq_len_k_start; + int* cu_seq_len_k_end; + float* logits; + float softmax_scale; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_kv_scales; + CUtensorMap tensor_map_weights; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto& arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm{}_fp8_mqa_logits< + {}, {}, + {}, {}, + {}, {}, + {}, {} + >); +}}; +)", arch, arch, + args.num_heads, args.head_dim, + args.block_q, args.block_kv, + args.num_q_stages, args.num_kv_stages, + args.num_specialized_threads, args.num_math_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.seq_len, args.seq_len_kv, static_cast(args.stride_kv), + args.cu_seq_len_k_start, args.cu_seq_len_k_end, + args.logits, + args.tensor_map_q, args.tensor_map_kv, + args.tensor_map_kv_scales, args.tensor_map_weights + )); + } +}; + +static void smxx_fp8_mqa_logits(const torch::Tensor& q, + const torch::Tensor& kv, const torch::Tensor& kv_scales, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const torch::Tensor& logits, + const int& seq_len, const int& seq_len_kv, const int& stride_kv, + const int& num_heads, const int& head_dim, + const int& seq_len_alignment) { + constexpr int block_qh = 128; + constexpr int block_kv = 256; + constexpr int num_specialized_threads = 128; + constexpr int num_math_threads = 512; + constexpr int num_q_stages = 3, num_kv_stages = 3; + const int block_q = block_qh / num_heads; + DG_HOST_ASSERT(block_qh % num_heads == 0); + DG_HOST_ASSERT(seq_len_alignment % block_q == 0); + + // Construct TMAs + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads, + head_dim, block_qh, head_dim, head_dim); + const auto& tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv, + head_dim, block_kv, head_dim, head_dim); + // According to the driver API, the minimal alignment is 256 bytes + // So it is safe for us to do a 16-byte OOB + const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_scales, + get_tma_aligned_size(seq_len_kv, static_cast(kv_scales.element_size())), + 1, block_kv, 1, 0, 0); + const auto& tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len, + num_heads, block_q, num_heads, 0); + + // Calculate shared memory size + int smem_size = 0; + const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast(q.element_size()); + const int smem_weight_size_per_stage = block_q * num_heads * static_cast(weights.element_size()); + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv.element_size()); + const int kv_scale_size_per_stage = block_kv * static_cast(kv_scales.element_size()); + smem_size += num_q_stages * smem_q_size_per_stage; + smem_size += num_kv_stages * smem_kv_size_per_stage; + smem_size += num_q_stages * smem_weight_size_per_stage; + smem_size += num_kv_stages * kv_scale_size_per_stage; + smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8; + smem_size += 4; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + + // Launch + const SM90FP8MQALogitsRuntime::Args& args = { + .seq_len = seq_len, + .seq_len_kv = seq_len_kv, + .stride_kv = stride_kv, + .num_heads = num_heads, .head_dim = head_dim, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .block_q = block_q, + .block_kv = block_kv, + .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr(), + .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), + .logits = logits.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_kv_scales = tensor_map_kv_scales, + .tensor_map_weights = tensor_map_weights, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(device_runtime->get_num_sms(), + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto& code = SM90FP8MQALogitsRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_mqa_logits", code); + SM90FP8MQALogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp new file mode 100644 index 0000000..4066036 --- /dev/null +++ b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp @@ -0,0 +1,228 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime { +public: + struct Args { + int aligned_batch_size; + int split_kv; + int num_sms; + + int batch_size; + int* context_lens; + int* schedule_metadata; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + const auto& arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&smxx_paged_mqa_logits_metadata< + {}, {}, {} + >); +}}; +)", arch, args.aligned_batch_size, args.split_kv, args.num_sms); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + args.context_lens, + args.schedule_metadata + )); + } +}; + +static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens, + const torch::Tensor& schedule_metadata, + const int& batch_size, const int& block_kv, const int& num_sms) { + constexpr int num_math_warpgroups = 4; + constexpr int num_threads = 32; + const int aligned_batch_size = align(batch_size, 32); + const int split_kv = block_kv * num_math_warpgroups; + + // Calculate shared memory size + const int smem_size = aligned_batch_size * static_cast(sizeof(int)); + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + + // Launch + const SMXXPagedMQALogitsMetadataRuntime::Args& args = { + .aligned_batch_size = aligned_batch_size, + .split_kv = split_kv, + .num_sms = num_sms, + .batch_size = batch_size, + .context_lens = context_lens.data_ptr(), + .schedule_metadata = schedule_metadata.data_ptr(), + .launch_args = LaunchArgs(1, num_threads, smem_size) + }; + const auto& code = SMXXPagedMQALogitsMetadataRuntime::generate(args); + const auto& runtime = compiler->build("smxx_paged_mqa_logits_metadata", code); + SMXXPagedMQALogitsMetadataRuntime::launch(runtime, args); +} + +class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int batch_size; + int next_n; + int num_heads; + int head_dim; + int block_kv; + int block_table_stride; + int logits_stride; + + int num_q_stages; + int num_kv_stages; + int split_kv; + + int* context_lens; + float* logits; + int* block_table; + int* schedule_meta; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_kv_scales; + CUtensorMap tensor_map_weights; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto& arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm{}_fp8_paged_mqa_logits< + {}, {}, + {}, {}, + {}, {}, + {}, + {}, {} + >); +}}; +)", arch, arch, + args.next_n, args.num_heads, + args.head_dim, args.block_kv, + args.num_q_stages, args.num_kv_stages, + args.split_kv, + args.num_specialized_threads, args.num_math_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + static_cast(args.logits_stride), + static_cast(args.block_table_stride), + args.context_lens, args.logits, + args.block_table, args.schedule_meta, + args.tensor_map_q, args.tensor_map_kv, + args.tensor_map_kv_scales, args.tensor_map_weights + )); + } +}; + +static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, + const torch::Tensor& kv_cache, + const torch::Tensor& kv_cache_scales, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& logits, + const torch::Tensor& block_table, + const torch::Tensor& schedule_meta, + const int& batch_size, const int& next_n, + const int& num_heads, const int& head_dim, + const int& num_kv_blocks, const int& block_kv, + const int& kv_cache_stride_bytes, + const int& logits_stride, + const int& block_table_stride, + const int& num_sms, + const int& num_math_warp_groups) { + const int num_specialized_threads = 128; + const int num_math_threads = num_math_warp_groups * 128; + const int num_q_stages = 3, num_kv_stages = 3; + const int split_kv = num_math_warp_groups * block_kv; + DG_HOST_ASSERT(logits_stride % (num_math_warp_groups * block_kv) == 0); + + // Construct TMAs + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads, + head_dim, next_n * num_heads, head_dim, head_dim); + const auto& tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks, + head_dim, block_kv, 1, + head_dim, kv_cache_stride_bytes, head_dim); + // TODO: use 1D TMA + const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks, + block_kv, 1, kv_cache_stride_bytes / static_cast(sizeof(float)), 0); + const auto& tensor_map_weights = make_tma_2d_desc(weights, next_n * num_heads, batch_size, + next_n * num_heads, 1, next_n * num_heads, 0); + + // Calculate shared memory size + const int swizzle_alignment = head_dim * 8; + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); + const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast(weights.element_size()), swizzle_alignment); + const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment); + + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv_cache.element_size()); + const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast(kv_cache_scales.element_size()), swizzle_alignment); + const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment); + + const int smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + + // Launch + const SMXXFP8PagedMQALogitsRuntime::Args& args = { + .batch_size = batch_size, + .next_n = next_n, + .num_heads = num_heads, + .head_dim = head_dim, + .block_kv = block_kv, + .block_table_stride = block_table_stride, + .logits_stride = logits_stride, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .split_kv = split_kv, + .context_lens = context_lens.data_ptr(), + .logits = logits.data_ptr(), + .block_table = block_table.data_ptr(), + .schedule_meta = schedule_meta.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_kv_scales = tensor_map_kv_scales, + .tensor_map_weights = tensor_map_weights, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(num_sms, + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_paged_mqa_logits", code); + SMXXFP8PagedMQALogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 55d87da..da3403e 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -41,6 +41,10 @@ from deep_gemm_cpp import ( cublaslt_gemm_tn, cublaslt_gemm_tt, # Einsum kernels einsum, + # Attention kernels + fp8_mqa_logits, + get_paged_mqa_logits_metadata, + fp8_paged_mqa_logits, # Layout kernels transform_sf_into_required_layout ) diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh new file mode 100644 index 0000000..52f4be6 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh @@ -0,0 +1,317 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +// ReSharper disable once CppNotAllPathsReturnValue +template +static constexpr int to_swizzle_cute_type() { + DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling"); + if constexpr (kHeadDim == 32) + return static_cast(cute::SM90::GMMA::LayoutType::B32); + if constexpr (kHeadDim == 64) + return static_cast(cute::SM90::GMMA::LayoutType::B64); + if constexpr (kHeadDim == 128) + return static_cast(cute::SM90::GMMA::LayoutType::B128); +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_kv, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + float* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (threadIdx.x / 32 == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + + // Initialize barriers + const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 32; + constexpr uint32_t kNumMathRegisters = 112; + + // Block scheduler + uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; + const auto& get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + gridDim.x, q_iter_idx + 1}; + }; + const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + start = min(start, min(__ldg(cu_seq_len_k_start + q_idx), seq_len_kv)); + end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv)); + } + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase + }; + }; + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // Only the first warp remains + if (not is_tma_load_warp) + return; + + // Prefetch + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma_copy(&tensor_map_q, reinterpret_cast(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma_copy(&tensor_map_weights, reinterpret_cast(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma_copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma_copy(&tensor_map_kv_scales, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& thread_idx = threadIdx.x % kNumMathThreads; + const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4]; + + const auto& warp_offset = warp_idx * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); + float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, + to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, + to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == BLOCK_Q, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto& transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + logits[q_idx * stride_kv + kv_offset + v_0_offset] = v_0; + logits[q_idx * stride_kv + kv_offset + v_1_offset] = v_1; + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } +} + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh new file mode 100644 index 0000000..edba79d --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -0,0 +1,403 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +template +__global__ __launch_bounds__(32, 1) +void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t* context_lens, uint32_t* schedule_metadata) { + DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); + const uint32_t lane_idx = get_lane_idx(); + + uint32_t num_segs[kAlignedBatchSize / 32]; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + const uint32_t& context_len = (k * 32 + lane_idx < batch_size ? __ldg(context_lens + k * 32 + lane_idx) : 0); + num_segs[k] = ceil_div(context_len, SPLIT_KV); + } + + __shared__ uint32_t prefix_sum[kAlignedBatchSize]; + uint32_t sum = 0; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + uint32_t x = num_segs[k]; + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset); + x += (lane_idx >= offset ? y : 0); + } + x += sum; + prefix_sum[k * 32 + lane_idx] = x; + sum = __shfl_sync(0xffffffff, x, 31); + } + + const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t q_idx = 0; + while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts) + ++ q_idx; + const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]); + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } +} + +template +struct PagedMQALogitsScheduler { + uint32_t batch_size; + const uint32_t* context_lens; + + uint32_t current_q_idx, current_kv_idx; + uint32_t end_q_idx, end_kv_idx; + uint32_t current_num_kv; + + __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx, + const uint32_t* context_lens, const uint32_t* schedule_meta) { + this->batch_size = batch_size; + this->context_lens = context_lens; + + const auto& current_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx); + const auto& end_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx + 1); + current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumMathWarpGroups; + end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumMathWarpGroups; + + current_num_kv = current_q_idx < batch_size ? ceil_div(__ldg(this->context_lens + current_q_idx), BLOCK_KV) : 0; + } + + __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) { + q_idx = current_q_idx; + kv_idx = current_kv_idx; + num_kv = current_num_kv; + + if (q_idx == end_q_idx and kv_idx == end_kv_idx) + return false; + + current_kv_idx += kNumMathWarpGroups; + if (current_kv_idx >= current_num_kv) { + ++ current_q_idx; + current_kv_idx = 0; + current_num_kv = current_q_idx < batch_size ? ceil_div(__ldg(this->context_lens + current_q_idx), BLOCK_KV) : 0; + } + + return true; + } + + __device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const { + return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx; + } +}; + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint64_t logits_stride, const uint64_t block_table_stride, + const uint32_t* context_lens, float* logits, + const uint32_t* block_table, const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + + constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); + + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + + constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q data and barriers on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); + + // Separate math warpgroups and tma load warps into KV groups + // Each math warpgroup corresponds to a tma load warp + const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); + + // Per group KV data and barriers on shared memory + const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); + + // Initialize barriers + if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) { + if (kv_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + } + if (kv_group_idx < kNumMathWarpGroups) { + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(128); + } + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 64; + constexpr uint32_t kNumMathRegisters = 104; + + // Scheduler + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); + + // Q and KV pipeline + const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + if (kv_group_idx >= kNumMathWarpGroups) + return; + + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + if (kv_group_idx == 0 and cute::elect_one_sync()) { + tma_copy(&tensor_map_q, reinterpret_cast(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, reinterpret_cast(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, q_idx); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx, num_kv; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_idx), q_iter_idx = 1; + + int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when current Q changes + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + q_idx = next_q_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_idx + 1); + } + + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + kv_group_idx + + lane_idx * kNumMathWarpGroups < num_kv ? + __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); + } + const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + if (cute::elect_one_sync()) { + tma_3d_copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + smem_kv[kv_stage_idx], 0, 0, kv_block_idx); + tma_copy(&tensor_map_kv_scales, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4]; + const auto& sub_warp_offset = (warp_idx % 4) * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + // Current Q changes + if (q_idx != next_q_idx) { + // Release Last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + } + + // Get current Q and KV index + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + + // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + + // Read per-KV scales + float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); + float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); + + // Wait WGMMA + warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == kNextN, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto& transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + logits[kv_offset + i * logits_stride + v_0_offset] = v_0; + logits[kv_offset + i * logits_stride + v_1_offset] = v_1; + } + } + } +} + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh b/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh new file mode 100644 index 0000000..b5d24e7 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include + +namespace deep_gemm { + +template +__global__ __launch_bounds__(kNumWarps * 32, 1) +void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_kv, + const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) { + const uint32_t& num_sms = gridDim.x; + const uint32_t& sm_idx = blockIdx.x; + const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + constexpr float neg_inf = -cute::numeric_limits::infinity(); + + // Allocate filled `-inf` shared memory + extern __shared__ __align__(1024) float smem_buffer[]; + #pragma unroll + for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32) + smem_buffer[i] = neg_inf; + cute::tma_store_fence(); + __syncthreads(); + + // Assign sequence to each warp + const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx, + const uint32_t& start, const uint32_t& total) -> cute::tuple { + const auto& per = total / num, rem = total % num; + return {start + idx * per + min(idx, rem), per + (idx < rem)}; + }; + CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len); + CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len); + + if (cute::elect_one_sync()) { + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); + const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; + const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + + for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) { + const auto& right = min(left + BLOCK_KV, static_cast(stride_kv)); + if (right <= ks or ke <= left) { + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + left, (right - left) * sizeof(float)); + } else { + if (left < aligned_ks) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + left, (aligned_ks - left) * sizeof(float)); + if (aligned_ke < right) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_kv + aligned_ke, (right - aligned_ke) * sizeof(float)); + } + } + } + } + + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); + const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; + const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + for (uint32_t j = aligned_ks; j < ks; ++ j) + logits[i * stride_kv + j] = neg_inf; + for (uint32_t j = ke; j < aligned_ke; ++ j) + logits[i * stride_kv + j] = neg_inf; + } +} + +} diff --git a/tests/test_attention.py b/tests/test_attention.py index 98ffa20..b3305ed 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -55,6 +55,186 @@ def test_gemm_skip_head_mid() -> None: print() +def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), device=x.device, dtype=torch.uint8) + x_fp8[ :, : block_size * head_dim] = x_scaled.view(num_blocks, block_size * head_dim).view(dtype=torch.uint8) + x_fp8[ :, block_size * head_dim :] = sf.view(num_blocks, block_size).view(dtype=torch.uint8) + return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) + + +def generate_cp_test_data(seq_len, seq_len_kv): + assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0 + chunk_size = seq_len // 2 + cp_size = seq_len_kv // seq_len + # Select an arbitrary CP rank + cp_id = cp_size // 3 + ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') + ke = torch.zeros(seq_len, dtype=torch.int, device='cuda') + for i in range(chunk_size): + ke[i] = cp_id * chunk_size + i + ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i + return ks, ke + + +def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False): + seq_len_kv = kv.shape[0] + + if cost_only: + start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv) + end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) + count_ones_per_row = (end - start).clamp(min=0) + return count_ones_per_row.sum() + + k = kv + q = q.float() + k = k.float() + + mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + + score = torch.einsum('mhd,nd->hmn', q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float('-inf')) + + cost = mask.sum() + return logits, cost + + +def test_mqa_logits(): + print('Testing FP8 MQA Logits:') + num_heads, head_dim = 64, 128 + for seq_len in (2048, 4096): + for seq_len_kv in (4096, 8192, 16384, 32768, 65536, 131072): + for disable_cp in (False, True): + q = torch.randn(seq_len, num_heads, head_dim, device='cuda', dtype=torch.bfloat16) + kv = torch.randn(seq_len_kv, head_dim, device='cuda', dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device='cuda', dtype=torch.float32) + + if disable_cp: + ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') + ke = torch.arange(seq_len, dtype=torch.int, device='cuda') + (seq_len_kv - seq_len) + else: + ks, ke = generate_cp_test_data(seq_len, seq_len_kv) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False) + logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) + + do_check = (seq_len_kv < 32768) + if do_check: + ref_logits, ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + ref_neginf_mask = (ref_logits == float('-inf')) + neginf_mask = (logits == float('-inf')) + assert torch.equal(neginf_mask, ref_neginf_mask) + + ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) + logits = logits.masked_fill(neginf_mask, 0) + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f"{diff=}" + else: + ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke, cost_only=True) + + tflops = 2 * ref_cost * num_heads * head_dim / 1e12 + t, clean_t = bench_kineto(lambda: deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke), + ('fp8_mqa_logits', 'clean_logits')) + clean_bytes = (seq_len * seq_len_kv - ref_cost) * 4 + count_bytes(ks, ke) + print(f' > S={seq_len:4}, SKV={seq_len_kv:6}, H={num_heads:3}, D={head_dim:3}, CP={0 if disable_cp else 1}: ' + f'{tflops / t:4.0f} TFLOPS, {t * 1e6:4.0f} us, ' + f'{(count_bytes(q_fp8, kv_fp8, weights, ks, ke) + ref_cost * 4) / t / 1e9:4.0f} GB/s | ' + f'clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s') + print() + + +def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor, + weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, + max_model_len: int): + batch_size, next_n, heads, dim = q.size() + num_block, block_size, _, dim = kv_cache.size() + logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32) + context_lens = context_lens.tolist() + for i in range(batch_size): + context_len = context_lens[i] + q_offsets = torch.arange(context_len - next_n, context_len, device='cuda') + weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous() + for block_rk in range(ceil_div(context_len, block_size)): + block_idx = block_tables[i][block_rk] + qx, kx = q[i], kv_cache[block_idx] + k_offsets = torch.arange(block_rk * block_size, (block_rk + 1) * block_size, device='cuda') + mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) + s = torch.where(mask[None, :, :], (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype), float('-inf')) + s = torch.relu(s) * weight_slice[..., None] + s = s.sum(dim=0) + logits[i * next_n:(i + 1) * next_n, block_rk * block_size: (block_rk + 1) * block_size] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) + return logits + + +def test_paged_mqa_logits(): + print('Testing FP8 Paged MQA Logits:') + max_model_len = 111 * 1000 + for batch_size, next_n in [(64, 1), (64, 2), (128, 1)]: + for heads, index_dim in [(64, 128)]: + for avg_kv in (8192, 32768): + num_blocks, blocksize = max_model_len * 3, 64 + + q = torch.randn((batch_size, next_n, heads, index_dim), device='cuda', dtype=torch.bfloat16) + kv_cache = torch.randn((num_blocks, blocksize, 1, index_dim), device='cuda', dtype=torch.bfloat16) + weights = torch.randn((batch_size * next_n, heads), device='cuda', dtype=torch.float32) + + context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size, )).cuda().to(torch.int32) + max_block_len = (context_lens.max().item() + blocksize - 1) // blocksize * blocksize + block_tables = torch.zeros((batch_size, max_block_len), device='cuda', dtype=torch.int32) + + counter = 0 + block_idx_pool = list(range(num_blocks)) + random.shuffle(block_idx_pool) + for i in range(batch_size): + ctx_len = context_lens[i].item() + for j in range(ceil_div(ctx_len, blocksize)): + block_tables[i][j] = block_idx_pool[counter] + counter += 1 + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) + + schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, 132) + logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True) + + ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len) + positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1) + row_indices = torch.arange(batch_size * next_n, device='cuda') // next_n + next_n_offset = torch.arange(batch_size * next_n, device='cuda') % next_n + ref_neginf_mask = ~(positions <= (context_lens[row_indices] - next_n + next_n_offset).unsqueeze(1)) + + neginf_mask = (logits == float('-inf')) + assert torch.equal(neginf_mask, ref_neginf_mask) + + logits = logits.masked_fill(neginf_mask, 0) + ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f"{diff=}" + + sum_lens = sum(context_lens.to(torch.int64)) + tflops = 2 * sum_lens * next_n * heads * index_dim / 1e12 + input_bytes = count_bytes(q_fp8, weights, context_lens) + sum_lens * (index_dim + 4) + (sum_lens / blocksize) * 4 + output_bytes = sum_lens * next_n * 4 + t, clean_t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True), + ('fp8_paged_mqa_logits', 'clean_logits')) + clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens) + print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: ' + f'{tflops / t:3.0f} TFLOPS, {t * 1e6:3.0f} us, ' + f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s | ' + f'clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s') + print() + + if __name__ == '__main__': torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -62,3 +242,7 @@ if __name__ == '__main__': random.seed(0) test_gemm_skip_head_mid() + + if get_arch_major() == 9: + test_mqa_logits() + test_paged_mqa_logits()