diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp index 2962eaf..eb03726 100644 --- a/csrc/apis/attention.hpp +++ b/csrc/apis/attention.hpp @@ -108,7 +108,7 @@ static torch::Tensor fp8_mqa_logits(const torch::Tensor& q, // Dispatch implementation const auto& arch_major = device_runtime->get_arch_major(); - if (arch_major == 9) { + if (arch_major == 9 or arch_major == 10) { smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, seq_len, seq_len_kv, aligned_seq_len_kv, num_heads, head_dim, seq_len_alignment); } else { @@ -130,7 +130,7 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_ // Dispatch implementation const auto& arch_major = device_runtime->get_arch_major(); - if (arch_major == 9) { + if (arch_major == 9 or arch_major == 10) { smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, block_kv, num_sms); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); @@ -205,7 +205,7 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, // Dispatch implementation const auto& arch_major = device_runtime->get_arch_major(); - if (arch_major == 9) { + if (arch_major == 9 or arch_major == 10) { smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, num_math_warp_groups); diff --git a/csrc/indexing/main.cu b/csrc/indexing/main.cu index ed5a214..6419a85 100644 --- a/csrc/indexing/main.cu +++ b/csrc/indexing/main.cu @@ -9,6 +9,8 @@ // Attention kernels #include #include +#include +#include #include // Einsum kernels diff --git a/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp index 4066036..38bbfb9 100644 --- a/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp @@ -58,6 +58,7 @@ static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens, // Calculate shared memory size const int smem_size = aligned_batch_size * static_cast(sizeof(int)); DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); // Launch const SMXXPagedMQALogitsMetadataRuntime::Args& args = { @@ -164,6 +165,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, const int& num_math_warp_groups) { const int num_specialized_threads = 128; const int num_math_threads = num_math_warp_groups * 128; + const int num_extra_threads = device_runtime->get_arch_major() == 10 ? 128 : 0; const int num_q_stages = 3, num_kv_stages = 3; const int split_kv = num_math_warp_groups * block_kv; DG_HOST_ASSERT(logits_stride % (num_math_warp_groups * block_kv) == 0); @@ -183,6 +185,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, // Calculate shared memory size const int swizzle_alignment = head_dim * 8; + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast(weights.element_size()), swizzle_alignment); const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment); @@ -191,8 +194,13 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast(kv_cache_scales.element_size()), swizzle_alignment); const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment); - const int smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size; + // Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90 + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int smem_tmem_ptr = 4; + + const int smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr; DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); // Launch const SMXXFP8PagedMQALogitsRuntime::Args& args = { @@ -217,7 +225,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, .num_specialized_threads = num_specialized_threads, .num_math_threads = num_math_threads, .launch_args = LaunchArgs(num_sms, - num_specialized_threads + num_math_threads, + num_specialized_threads + num_math_threads + num_extra_threads, smem_size) }; const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh new file mode 100644 index 0000000..7890c26 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -0,0 +1,385 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; +using namespace deep_gemm::sm100; + +template +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_kv, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + float* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64` + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + + // Types + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warp_in_group_idx = warp_idx % 4; + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); + + // Align to 512 bytes for swizzle-64B + extern __shared__ __align__(512) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_WEIGHT_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + + // TMA configs + constexpr uint32_t kNumTmemCols = BLOCK_Q * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); + auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); + + // Tensor memory allocation + auto tmem_ptr_in_smem = reinterpret_cast(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2); + + // Initialize barriers + DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads"); + const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32)); + const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1)); + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (is_umma_warp) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 32; + constexpr uint32_t kNumMathRegisters = 112; + + // Block scheduler + uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; + const auto& get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + gridDim.x, q_iter_idx + 1}; + }; + const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + start = min(start, min(__ldg(cu_seq_len_k_start + q_idx), seq_len_kv)); + end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv)); + } + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase + }; + }; + + // UMMA settings + // Construct instruction with layout F + constexpr uint32_t UMMA_M = 64; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; + + if (is_tma_load_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Prefetch + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma_copy(&tensor_map_q, reinterpret_cast(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma_copy(&tensor_map_weights, reinterpret_cast(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma_copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma_copy(&tensor_map_kv_scales, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else if (is_umma_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue UMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = make_umma_desc( + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); + auto b_desc = make_umma_desc( + smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } else if (warp_idx >= kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + // Offsets + const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); + float weights[BLOCK_Q][kNumHeads / 4]; + const auto& warp_offset = warp_idx * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + auto scale_kv = make_float2(ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset), + ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset)); + + // Wait UMMA arrival + full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + // Load from the tensor memory + constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128; + uint32_t shifted_accum[kNumLDTMElems]; + DG_STATIC_ASSERT(kNumLDTMElems == 16 or kNumLDTMElems == 32 or kNumLDTMElems == 64, "Invalid LDTM"); + auto tmem_load = [&](auto... Is) { + if constexpr (kNumLDTMElems == 16) { + cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + // Release UMMA empty + if (i == BLOCK_Q - 1) + empty_umma_barriers[warpgroup_idx]->arrive(); + + // Transform + const auto& transform_2 = [&](const uint32_t& j, const uint32_t& k, const float2& sum) { + auto a = make_float2(fmaxf(*reinterpret_cast(&shifted_accum[j * 4 + k]), 0), + fmaxf(*reinterpret_cast(&shifted_accum[j * 4 + k + 2]), 0)); + auto b = make_float2(weights[i][j * 2 + k], weights[i][j * 2 + k]); + return __ffma2_rn(a, b, sum); + }; + + // Intra-thread reduction + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 8; ++ j) { + sum_0 = transform_2(j, 0, sum_0); + sum_1 = transform_2(j, 1, sum_1); + } + auto v = __fmul2_rn(__fadd2_rn(sum_0, sum_1), scale_kv); + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = 1u << j; + v.x += __shfl_xor_sync(0xffffffffu, v.x, offset); + v.y += __shfl_xor_sync(0xffffffffu, v.y, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + logits[q_idx * stride_kv + kv_offset + v_0_offset] = v.x; + logits[q_idx * stride_kv + kv_offset + v_1_offset] = v.y; + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + + // Free tensor memory + __syncthreads(); + if (is_tma_load_warp) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); +} + +} // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh new file mode 100644 index 0000000..4a53421 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -0,0 +1,404 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; +using namespace deep_gemm::sm100; + +template +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads + 128, 1) +void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint64_t logits_stride, const uint64_t block_table_stride, + const uint32_t* context_lens, float* logits, + const uint32_t* block_table, const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + + constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); + + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + + constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); + + static constexpr uint32_t SMEM_UMMA_SIZE = kNumMathWarpGroups * 2 * 8 + static_cast(sizeof(uint32_t)); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q data and barriers on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); + + // Separate math warpgroups and tma load warps into KV groups + // Each math warpgroup corresponds to a tma load warp + const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); + + // Per group KV data and barriers on shared memory + const auto& smem_kv_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); + + // UMMA barriers and TMEM pointer on shared memroy + auto umma_barrier_ptr = reinterpret_cast(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kNumMathWarpGroups); + auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); + auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); + + constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + const bool& is_math_warp = (warp_idx < (kNumMathThreads / 32)); // 0 ~ 16 + const bool& is_tma_load_warp = (warp_idx >= (kNumMathThreads / 32) and warp_idx < (kNumMathThreads / 32 + 4)); // 16 ~ 20 + const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 4)); // 20 + + // Initialize barriers + if (is_tma_load_warp and cute::elect_one_sync()) { + if (kv_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + } + if (kv_group_idx < kNumMathWarpGroups) { + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(128); + } + } + cutlass::arch::fence_barrier_init(); + } + if (is_umma_warp) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 32; + constexpr uint32_t kNumMathRegisters = 104; + + // Scheduler + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); + + // Q and KV pipeline + const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + // UMMA settings + // Construct instruction with layout F + constexpr uint32_t UMMA_M = 64; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = kNextN * kNumHeads; + + if (is_tma_load_warp) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + if (kv_group_idx >= kNumMathWarpGroups) + return; + + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + if (kv_group_idx == 0 and cute::elect_one_sync()) { + tma_copy(&tensor_map_q, reinterpret_cast(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, reinterpret_cast(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, q_idx); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx, num_kv; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_idx), q_iter_idx = 1; + + int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when current Q changes + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + q_idx = next_q_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_idx + 1); + } + + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + kv_group_idx + + lane_idx * kNumMathWarpGroups < num_kv ? + __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); + } + const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + if (cute::elect_one_sync()) { + tma_3d_copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + smem_kv[kv_stage_idx], 0, 0, kv_block_idx); + tma_copy(&tensor_map_kv_scales, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + } + } else if (is_umma_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 0; + + auto smem_kv = PatternVisitor([&](const uint32_t& stage_idx) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_SIZE_PER_STAGE * stage_idx); + }); + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + if (q_idx != next_q_idx) { + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + } + + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + + DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(umma_phase & 1); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = make_umma_desc( + smem_kv[kv_stage_idx] + i * SMEM_KV_PIPE_SIZE, 0, k * UMMA_K); + auto b_desc = make_umma_desc( + smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + umma_phase ^= 1; + } + } else if (is_math_warp) { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // Offsets + const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); + float weights[kNextN][kNumHeads / 4]; + const auto& sub_warp_offset = (warp_idx % 4) * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 0; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + // Current Q changes + if (q_idx != next_q_idx) { + // Release Last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + } + + // Get current Q and KV index + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + + // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + auto scale_kv = make_float2(ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset), + ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset)); + + empty_umma_barriers[warpgroup_idx]->arrive(); + // Wait UMMA arrival + full_umma_barriers[warpgroup_idx]->wait(umma_phase & 1); + umma_phase ^= 1; + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + // Load from the tensor memory + constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128; + uint32_t shifted_accum[kNumLDTMElems]; + DG_STATIC_ASSERT(kNumLDTMElems == 16 or kNumLDTMElems == 32 or kNumLDTMElems == 64, "Invalid LDTM"); + auto tmem_load = [&](auto... Is) { + if constexpr (kNumLDTMElems == 16) { + cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + // Transform + const auto& transform_2 = [&](const uint32_t& j, const uint32_t& k, const float2& sum) { + auto a = make_float2(fmaxf(*reinterpret_cast(&shifted_accum[j * 4 + k]), 0), + fmaxf(*reinterpret_cast(&shifted_accum[j * 4 + k + 2]), 0)); + auto b = make_float2(weights[i][j * 2 + k], weights[i][j * 2 + k]); + return __ffma2_rn(a, b, sum); + }; + + // Intra-thread reduction + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 8; ++ j) { + sum_0 = transform_2(j, 0, sum_0); + sum_1 = transform_2(j, 1, sum_1); + } + auto v = __fmul2_rn(__fadd2_rn(sum_0, sum_1), scale_kv); + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = 1u << j; + v.x += __shfl_xor_sync(0xffffffffu, v.x, offset); + v.y += __shfl_xor_sync(0xffffffffu, v.y, offset); + } + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + logits[kv_offset + i * logits_stride + v_0_offset] = v.x; + logits[kv_offset + i * logits_stride + v_1_offset] = v.y; + } + } + } else { + cutlass::arch::warpgroup_reg_dealloc(); + } + + // Free tensor memory + __syncthreads(); + if (is_umma_warp) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); +} + +} // namespace deep_gemm diff --git a/tests/test_attention.py b/tests/test_attention.py index b3305ed..1baa80f 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -204,7 +204,7 @@ def test_paged_mqa_logits(): q_fp8 = q.to(torch.float8_e4m3fn) kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) - schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, 132) + schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, deep_gemm.get_num_sms()) logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True) ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len) @@ -229,7 +229,7 @@ def test_paged_mqa_logits(): ('fp8_paged_mqa_logits', 'clean_logits')) clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens) print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: ' - f'{tflops / t:3.0f} TFLOPS, {t * 1e6:3.0f} us, ' + f'{tflops / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, ' f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s | ' f'clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s') print() @@ -243,6 +243,5 @@ if __name__ == '__main__': test_gemm_skip_head_mid() - if get_arch_major() == 9: - test_mqa_logits() - test_paged_mqa_logits() + test_mqa_logits() + test_paged_mqa_logits()