diff --git a/csrc/cpu/cpu_attn_impl.hpp b/csrc/cpu/cpu_attn_impl.hpp
index d2479e118..89cf2dc3a 100644
--- a/csrc/cpu/cpu_attn_impl.hpp
+++ b/csrc/cpu/cpu_attn_impl.hpp
@@ -1107,7 +1107,8 @@ class AttentionMainLoop {
       if (sliding_window_left != -1) {
         pos = std::max(pos, curr_token_pos - sliding_window_left);
       }
-      return pos;
+      // Clamp to the tile end to avoid OOB when the window starts past the tile
+      return std::min(pos, kv_tile_end_pos);
     }();
 
     int32_t right_kv_pos = [&]() {
diff --git a/csrc/cpu/cpu_attn_neon.hpp b/csrc/cpu/cpu_attn_neon.hpp
index 827f0cfbc..3523893c3 100644
--- a/csrc/cpu/cpu_attn_neon.hpp
+++ b/csrc/cpu/cpu_attn_neon.hpp
@@ -4,6 +4,9 @@
 #include "cpu_attn_impl.hpp"
 #include <arm_neon.h>
 #include <cstdint>
+#ifdef ARM_BF16_SUPPORT
+  #include "cpu_attn_neon_bfmmla.hpp"
+#endif
 
 namespace cpu_attention {
 namespace {
@@ -57,7 +60,7 @@ FORCE_INLINE void load_row8_B_as_f32(const c10::BFloat16* p,
 #endif
 }
 
-// Mx8, with 1 <= M <= 8, K streamed, unroll-by-4 with NEON FMLAs
+// Mx8, with 1 <= M <= 8, K streamed, unroll-by-4 with ASIMD FMLAs
 // #Loads = (K // 4) * (M + 4 * sizeof(kv_cache_t) / 2)
 // #FMLAs = (K // 4) * (4 * 2 * M)
 // We have (4 * 2 * M) FMLAs for (M + 4 * sizeof(kv_cache_t) / 2) loads
@@ -381,6 +384,18 @@ class AttentionImpl {
     }
   }
 };
+
+#ifdef ARM_BF16_SUPPORT
+// For BF16 on Arm, reuse the BFMMLA kernels with 32-token alignment.
+template <int64_t head_dim>
+class AttentionImpl<ISA::NEON, c10::BFloat16, head_dim>
+    : public AttentionImplNEONBFMMLA<32, head_dim, ISA::NEON> {};
+#endif
 }  // namespace cpu_attention
 
-#endif  // #ifndef CPU_ATTN_NEON_HPP
+#undef BLOCK_SIZE_ALIGNMENT
+#undef HEAD_SIZE_ALIGNMENT
+#undef MAX_Q_HEAD_NUM_PER_ITER
+
+#endif  // #ifndef CPU_ATTN_ASIMD_HPP
diff --git a/csrc/cpu/cpu_attn_neon_bfmmla.hpp b/csrc/cpu/cpu_attn_neon_bfmmla.hpp
new file mode 100644
index 000000000..fb133aa13
--- /dev/null
+++ b/csrc/cpu/cpu_attn_neon_bfmmla.hpp
@@ -0,0 +1,682 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#ifndef CPU_ATTN_NEON_BFMMLA_HPP
+#define CPU_ATTN_NEON_BFMMLA_HPP
+
+#include "cpu_attn_impl.hpp"
+
+#include <arm_neon.h>
+
+#include <cstdint>
+#include <vector>
+
+namespace cpu_attention {
+
+namespace {
+
+// BFMMLA tile dimensions
+constexpr int32_t TILE_ROWS = 2;  // M dimension
+constexpr int32_t TILE_K = 4;     // K reduction
+constexpr int32_t TILE_COLS = 2;  // N dimension (column-pair)
+
+// Derived constants
+constexpr int32_t OUTPUT_COLS_PER_BLOCK = 8;   // 4 column-pairs
+constexpr int32_t K_TOKENS_PER_GROUP = 8;      // Tokens grouped in K cache
+constexpr int32_t V_TOKENS_PER_ROW_BLOCK = 4;  // Tokens per V cache row block
+constexpr int32_t K_INNER_STRIDE = K_TOKENS_PER_GROUP * TILE_K;
+constexpr int32_t V_INNER_STRIDE = V_TOKENS_PER_ROW_BLOCK * TILE_COLS;
+constexpr int32_t PACK_ELEMENTS_PER_K_CHUNK = TILE_ROWS * TILE_K;  // A packing
+
+// Matrix Packing and Accumulator Helpers
+
+// Reshape two rows of Q into the BFMMLA-friendly interleaved layout.
+// Input:  row0 = [a0,a1,a2,a3, ...], row1 = [b0,b1,b2,b3, ...]
+// Output: [a0,a1,a2,a3,b0,b1,b2,b3, a4,a5,a6,a7,b4,b5,b6,b7]
+// For K tail (K % TILE_K != 0): pads with zeros to complete the final chunk
+FORCE_INLINE void reshape_Q_2xK_for_bfmmla(const c10::BFloat16* __restrict r0,
+                                           const c10::BFloat16* __restrict r1,
+                                           c10::BFloat16* __restrict dst,
+                                           int32_t K) {
+  const uint16_t* s0 = reinterpret_cast<const uint16_t*>(r0);
+  const uint16_t* s1 = reinterpret_cast<const uint16_t*>(r1);
+  uint16_t* d = reinterpret_cast<uint16_t*>(dst);
+
+  // Process TILE_K elements at a time (PACK_ELEMENTS_PER_K_CHUNK outputs)
+  int32_t k = 0;
+  for (; k + TILE_K <= K; k += TILE_K, d += PACK_ELEMENTS_PER_K_CHUNK) {
+    vst1q_u16(d, vcombine_u16(vld1_u16(s0 + k), vld1_u16(s1 + k)));
+  }
+
+  // Handle K tail: pack remaining elements with zero-padding
+  const int32_t tail = K - k;
+  if (tail > 0) {
+    // Pack remaining tail elements: [r0[k..k+tail-1], pad, r1[k..k+tail-1],
+    // pad]
+    for (int32_t t = 0; t < tail; ++t) {
+      d[t] = s0[k + t];
+      d[t + TILE_K] = s1[k + t];
+    }
+    // Zero-pad the rest
+    for (int32_t t = tail; t < TILE_K; ++t) {
+      d[t] = 0;
+      d[t + TILE_K] = 0;
+    }
+  }
+}
+
+// 2x2 accumulator load/store with compile-time row count
+template <int32_t m_rows>
+FORCE_INLINE float32x4_t load_acc_2x2(float* base, int64_t ldc, int col_off) {
+  static_assert(m_rows == 1 || m_rows == 2);
+  float32x2_t row0 = vld1_f32(base + col_off);
+  float32x2_t row1 =
+      (m_rows == 2) ? vld1_f32(base + ldc + col_off) : vdup_n_f32(0.f);
+  return vcombine_f32(row0, row1);
+}
+
+template <int32_t m_rows>
+FORCE_INLINE void store_acc_2x2(float32x4_t acc, float* base, int64_t ldc,
+                                int col_off) {
+  static_assert(m_rows == 1 || m_rows == 2);
+  vst1_f32(base + col_off, vget_low_f32(acc));
+  if constexpr (m_rows == 2) {
+    vst1_f32(base + ldc + col_off, vget_high_f32(acc));
+  }
+}
+
+// Initialize 4 column-pair accumulators for 2 rows (8 columns total)
+#define INIT_ACC_ROWPAIR_4(a0, a1, a2, a3, Crow, ldc, m_rows, accum) \
+  do {                                                               \
+    if (accum) {                                                     \
+      if (m_rows == 2) {                                             \
+        a0 = load_acc_2x2<2>(Crow, ldc, 0);                          \
+        a1 = load_acc_2x2<2>(Crow, ldc, 2);                          \
+        a2 = load_acc_2x2<2>(Crow, ldc, 4);                          \
+        a3 = load_acc_2x2<2>(Crow, ldc, 6);                          \
+      } else {                                                       \
+        a0 = load_acc_2x2<1>(Crow, ldc, 0);                          \
+        a1 = load_acc_2x2<1>(Crow, ldc, 2);                          \
+        a2 = load_acc_2x2<1>(Crow, ldc, 4);                          \
+        a3 = load_acc_2x2<1>(Crow, ldc, 6);                          \
+      }                                                              \
+    } else {                                                         \
+      a0 = a1 = a2 = a3 = vdupq_n_f32(0.f);                          \
+    }                                                                \
+  } while (0)
+
+// Store 4 column-pair accumulators back to C matrix
+#define STORE_ACC_ROWPAIR_4(a0, a1, a2, a3, Crow, ldc, m_rows) \
+  do {                                                         \
+    if (m_rows == 2) {                                         \
+      store_acc_2x2<2>(a0, Crow, ldc, 0);                      \
+      store_acc_2x2<2>(a1, Crow, ldc, 2);                      \
+      store_acc_2x2<2>(a2, Crow, ldc, 4);                      \
+      store_acc_2x2<2>(a3, Crow, ldc, 6);                      \
+    } else {                                                   \
+      store_acc_2x2<1>(a0, Crow, ldc, 0);                      \
+      store_acc_2x2<1>(a1, Crow, ldc, 2);                      \
+      store_acc_2x2<1>(a2, Crow, ldc, 4);                      \
+      store_acc_2x2<1>(a3, Crow, ldc, 6);                      \
+    }                                                          \
+  } while (0)
+
+// Perform 4 BFMMLA operations: acc += A @ B for 4 column-pairs
+#define BFMMLA_COMPUTE_4(r0, r1, r2, r3, a, b0, b1, b2, b3) \
+  do {                                                      \
+    r0 = vbfmmlaq_f32(r0, a, b0);                           \
+    r1 = vbfmmlaq_f32(r1, a, b1);                           \
+    r2 = vbfmmlaq_f32(r2, a, b2);                           \
+    r3 = vbfmmlaq_f32(r3, a, b3);                           \
+  } while (0)
+
+// Micro-kernel: updates a small fixed tile using BFMMLA.
+// RP = number of row-pairs (1,2,4)
+// Computes C[TILE_ROWS*RP, OUTPUT_COLS_PER_BLOCK] += A_packed @ B.
+// A_packed interleaves RP row-pairs; B layout is driven by the attention phase:
+//   - AttentionGemmPhase::QK -> token-column layout (Q @ K^T)
+//   - AttentionGemmPhase::PV -> token-row layout (P @ V)
+// K_static < 0 enables runtime K (PV only)
+template <int32_t RP, int32_t K_static, AttentionGemmPhase phase>
+FORCE_INLINE void gemm_rowpairs_x8_bfmmla_neon(
+    const bfloat16_t* const* __restrict A_packed_rp,
+    const int32_t* __restrict m_rows_rp, const bfloat16_t* __restrict B_blk,
+    float* __restrict C, int64_t ldc, bool accumulate, int64_t b_stride,
+    int32_t K_runtime = 0) {
+  static_assert(RP == 1 || RP == 2 || RP == 4, "RP must be 1,2,4");
+  static_assert(K_static < 0 || K_static % TILE_K == 0,
+                "K must be divisible by TILE_K");
+  static_assert(K_static >= 0 || phase == AttentionGemmPhase::PV,
+                "Runtime K only supported for PV");
+
+  constexpr bool runtime_k = (K_static < 0);
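+  // K_iters counts the number of full TILE_K chunks; for the runtime-K (PV)
+  // path, any remainder is handled by the scalar tail loop below.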
+  const int32_t K_iters =
+      runtime_k ? (K_runtime / TILE_K) : (K_static / TILE_K);
+  const int32_t K_tail = runtime_k ? (K_runtime % TILE_K) : 0;
+
+  if (!runtime_k) {
+    // Help the compiler fold away unused K_runtime when K is compile-time
+    (void)K_runtime;
+  }
+
+  auto* C_al = C;
+  const auto* B_al = B_blk;
+
+  // Setup A pointers
+  const bfloat16_t* a_ptr[4] = {
+      A_packed_rp[0],
+      (RP >= 2) ? A_packed_rp[1] : nullptr,
+      (RP >= 4) ? A_packed_rp[2] : nullptr,
+      (RP >= 4) ? A_packed_rp[3] : nullptr,
+  };
+
+  // Setup B pointers based on layout
+  const bfloat16_t* b_ptr[4];
+  if constexpr (phase == AttentionGemmPhase::PV) {
+    b_ptr[0] = B_blk + 0 * b_stride;
+    b_ptr[1] = B_blk + 1 * b_stride;
+    b_ptr[2] = B_blk + 2 * b_stride;
+    b_ptr[3] = B_blk + 3 * b_stride;
+  }
+
+  float32x4_t acc[4][4];
+
+// Initialize accumulators
+#define INIT_RP(rp)                                                            \
+  if constexpr (RP > rp) {                                                     \
+    INIT_ACC_ROWPAIR_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3],         \
+                       C_al + (rp * 2) * ldc, ldc, m_rows_rp[rp], accumulate); \
+  }
+  INIT_RP(0);
+  INIT_RP(1);
+  INIT_RP(2);
+  INIT_RP(3);
+#undef INIT_RP
+
+  // Main compute loop
+  for (int32_t ki = 0; ki < K_iters; ++ki) {
+    bfloat16x8_t b0, b1, b2, b3;
+    if constexpr (phase == AttentionGemmPhase::PV) {
+      b0 = vld1q_bf16(b_ptr[0] + ki * V_INNER_STRIDE);
+      b1 = vld1q_bf16(b_ptr[1] + ki * V_INNER_STRIDE);
+      b2 = vld1q_bf16(b_ptr[2] + ki * V_INNER_STRIDE);
+      b3 = vld1q_bf16(b_ptr[3] + ki * V_INNER_STRIDE);
+    } else {
+      const bfloat16_t* b_base = B_al + ki * b_stride;
+      b0 = vld1q_bf16(b_base + 0 * V_INNER_STRIDE);
+      b1 = vld1q_bf16(b_base + 1 * V_INNER_STRIDE);
+      b2 = vld1q_bf16(b_base + 2 * V_INNER_STRIDE);
+      b3 = vld1q_bf16(b_base + 3 * V_INNER_STRIDE);
+    }
+
+#define COMPUTE_RP(rp)                                                       \
+  if constexpr (RP > rp) {                                                   \
+    bfloat16x8_t a = vld1q_bf16(a_ptr[rp] + ki * PACK_ELEMENTS_PER_K_CHUNK); \
+    BFMMLA_COMPUTE_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3], a, b0,  \
+                     b1, b2, b3);                                            \
+  }
+    COMPUTE_RP(0);
+    COMPUTE_RP(1);
+    COMPUTE_RP(2);
+    COMPUTE_RP(3);
+#undef COMPUTE_RP
+  }
+
+  // K tail for runtime PV: fallback path
+  if constexpr (runtime_k) {
+    if (K_tail > 0) {
+      const int32_t tail_offset = K_iters * V_INNER_STRIDE;
+      const int32_t a_tail_offset = K_iters * PACK_ELEMENTS_PER_K_CHUNK;
+      for (int32_t kt = 0; kt < K_tail; ++kt) {
+        float32x4_t b_vecs[4];
+        for (int32_t p = 0; p < 4; ++p) {
+          const bfloat16_t* bp = b_ptr[p] + tail_offset + kt * TILE_COLS;
+          const float b0 = vcvtah_f32_bf16(bp[0]);
+          const float b1 = vcvtah_f32_bf16(bp[1]);
+          const float32x2_t b_pair = vset_lane_f32(b1, vdup_n_f32(b0), 1);
+          b_vecs[p] = vcombine_f32(b_pair, b_pair);
+        }
+
+#define TAIL_RP(rp)                                                     \
+  if constexpr (RP > rp) {                                              \
+    const bfloat16_t* ap = A_packed_rp[rp] + a_tail_offset;             \
+    float a_row0 = vcvtah_f32_bf16(ap[kt]);                             \
+    float a_row1 =                                                      \
+        (m_rows_rp[rp] == 2) ? vcvtah_f32_bf16(ap[kt + TILE_K]) : 0.0f; \
+    const float32x4_t a_vec =                                           \
+        vcombine_f32(vdup_n_f32(a_row0), vdup_n_f32(a_row1));           \
+    for (int32_t p = 0; p < 4; ++p) {                                   \
+      acc[rp][p] = vmlaq_f32(acc[rp][p], a_vec, b_vecs[p]);             \
+    }                                                                   \
+  }
+        TAIL_RP(0);
+        TAIL_RP(1);
+        TAIL_RP(2);
+        TAIL_RP(3);
+#undef TAIL_RP
+      }
+    }
+  }
+
+  // Store results
+#define STORE_RP(rp)                                                    \
+  if constexpr (RP > rp) {                                              \
+    STORE_ACC_ROWPAIR_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3], \
+                        C_al + (rp * 2) * ldc, ldc, m_rows_rp[rp]);     \
+  }
+  STORE_RP(0);
+  STORE_RP(1);
+  STORE_RP(2);
+  STORE_RP(3);
+#undef STORE_RP
+}
+
+// Meso-kernel: packs a small MBxK slice of A, then tiles over N and calls the
+// micro-kernel for each OUTPUT_COLS_PER_BLOCK chunk.
+// K_static < 0 enables runtime K (PV only).
+template <int32_t MB, int32_t N, int32_t K_static, AttentionGemmPhase phase>
+FORCE_INLINE void gemm_packA_compute_MB_xN(
+    const c10::BFloat16* __restrict A, const c10::BFloat16* __restrict B,
+    float* __restrict C, int32_t K_runtime, int64_t lda, int64_t ldc,
+    int64_t b_layout_stride, int64_t b_reduction_stride, bool accumulate) {
+  static_assert(MB >= 1 && MB <= 8, "MB must be in [1,8]");
+  static_assert(N % OUTPUT_COLS_PER_BLOCK == 0,
+                "N must be a multiple of OUTPUT_COLS_PER_BLOCK");
+  static_assert(K_static < 0 || K_static % TILE_K == 0,
+                "K must be divisible by TILE_K");
+  static_assert(K_static >= 0 || phase == AttentionGemmPhase::PV,
+                "Runtime K only supported for PV");
+
+  constexpr bool runtime_k = (K_static < 0);
+  const int32_t K_val = runtime_k ? K_runtime : K_static;
+
+  // Keep small packs on-stack to avoid heap churn
+  constexpr int32_t STACK_PACK_STRIDE =
+      (1024 / TILE_K) * PACK_ELEMENTS_PER_K_CHUNK;
+
+  constexpr int32_t ROW_PAIRS = (MB + 1) / TILE_ROWS;
+  const int32_t pack_stride =
+      runtime_k ? ((K_val + TILE_K - 1) / TILE_K) * PACK_ELEMENTS_PER_K_CHUNK
+                : (K_static / TILE_K) * PACK_ELEMENTS_PER_K_CHUNK;
+
+  alignas(64) c10::BFloat16 A_packed_stack[ROW_PAIRS * STACK_PACK_STRIDE];
+  std::vector<c10::BFloat16> A_packed_heap;
+  c10::BFloat16* A_packed =
+      (pack_stride <= STACK_PACK_STRIDE)
+          ? A_packed_stack
+          : (A_packed_heap.resize(ROW_PAIRS * pack_stride),
+             A_packed_heap.data());
+
+  for (int32_t rp = 0; rp < ROW_PAIRS; ++rp) {
+    const int32_t m = rp * TILE_ROWS;
+    const int32_t m_rows = (m + 1 < MB) ? TILE_ROWS : 1;
+    const c10::BFloat16* A0 = A + m * lda;
+    const c10::BFloat16* A1 = (m_rows == TILE_ROWS) ? (A + (m + 1) * lda) : A0;
+    reshape_Q_2xK_for_bfmmla(A0, A1, A_packed + rp * pack_stride, K_val);
+  }
+
+  for (int32_t n = 0; n < N; n += OUTPUT_COLS_PER_BLOCK) {
+    const c10::BFloat16* B_blk_c10 =
+        (phase == AttentionGemmPhase::PV)
+            ? (B + (n / TILE_COLS) * b_layout_stride)
+            : (B + (n / OUTPUT_COLS_PER_BLOCK) * b_layout_stride);
+    const bfloat16_t* B_blk = reinterpret_cast<const bfloat16_t*>(B_blk_c10);
+
+    // Process row-pairs in groups of 4, 2, then 1
+    int32_t row_pair_idx = 0;
+
+#define PROCESS_RP_GROUP(group_size)                                        \
+  for (; row_pair_idx + (group_size - 1) < ROW_PAIRS;                       \
+       row_pair_idx += group_size) {                                        \
+    const bfloat16_t* Ap[group_size];                                       \
+    int32_t mr[group_size];                                                 \
+    for (int32_t i = 0; i < group_size; ++i) {                              \
+      Ap[i] = reinterpret_cast<const bfloat16_t*>(                          \
+          A_packed + (row_pair_idx + i) * pack_stride);                     \
+      mr[i] = (((row_pair_idx + i) * TILE_ROWS + 1) < MB) ? TILE_ROWS : 1;  \
+    }                                                                       \
+    float* C_blk = C + (row_pair_idx * TILE_ROWS) * ldc + n;                \
+    if constexpr (runtime_k) {                                              \
+      gemm_rowpairs_x8_bfmmla_neon<group_size, K_static, phase>(            \
+          Ap, mr, B_blk, C_blk, ldc, accumulate, b_layout_stride, K_val);   \
+    } else {                                                                \
+      gemm_rowpairs_x8_bfmmla_neon<group_size, K_static, phase>(            \
+          Ap, mr, B_blk, C_blk, ldc, accumulate,                            \
+          (phase == AttentionGemmPhase::PV) ? b_layout_stride               \
+                                            : b_reduction_stride);          \
+    }                                                                       \
+  }
+
+    PROCESS_RP_GROUP(4);
+    PROCESS_RP_GROUP(2);
+    PROCESS_RP_GROUP(1);
+#undef PROCESS_RP_GROUP
+  }
+}
+
+// Macro-kernel: iterates over M in MB={8,4,2,1} chunks.
+// Supports compile-time K specialization when K >= 0; otherwise uses runtime K
+// (runtime K path is only supported for PV).
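+// e.g. M = 13 is processed as MB chunks of 8, then 4, then 1.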
+template <int32_t N, int32_t K, AttentionGemmPhase phase>
+FORCE_INLINE void gemm_macro_neon_bfmmla(
+    const c10::BFloat16* __restrict A, const c10::BFloat16* __restrict B,
+    float* __restrict C, int32_t M, int32_t K_runtime, int64_t lda, int64_t ldc,
+    int64_t b_layout_stride, int64_t b_reduction_stride, bool accumulate) {
+  static_assert(N % OUTPUT_COLS_PER_BLOCK == 0,
+                "N must be a multiple of OUTPUT_COLS_PER_BLOCK");
+
+  if constexpr (K >= 0) {
+    static_assert(K % TILE_K == 0, "K must be divisible by TILE_K");
+    for (int32_t m = 0; m < M;) {
+      const int32_t rem = M - m;
+      const c10::BFloat16* A_blk = A + m * lda;
+      float* C_blk = C + m * ldc;
+
+#define DISPATCH_MB(mb)                                                    \
+  gemm_packA_compute_MB_xN<mb, N, K, phase>(A_blk, B, C_blk, 0, lda, ldc, \
+                                            b_layout_stride,              \
+                                            b_reduction_stride, accumulate)
+
+      if (rem >= 8) {
+        DISPATCH_MB(8);
+        m += 8;
+      } else if (rem >= 4) {
+        DISPATCH_MB(4);
+        m += 4;
+      } else if (rem >= 2) {
+        DISPATCH_MB(2);
+        m += 2;
+      } else {
+        DISPATCH_MB(1);
+        m += 1;
+      }
+#undef DISPATCH_MB
+    }
+  } else {
+    static_assert(phase == AttentionGemmPhase::PV,
+                  "Runtime K specialization only supported for PV.");
+    const int32_t K_val = K_runtime;
+
+    for (int32_t m = 0; m < M;) {
+      const int32_t rem = M - m;
+      const c10::BFloat16* A_blk = A + m * lda;
+      float* C_blk = C + m * ldc;
+
+#define DISPATCH_MB_RUNTIME(mb)                                                \
+  gemm_packA_compute_MB_xN<mb, N, K, phase>(A_blk, B, C_blk, K_val, lda, ldc, \
+                                            b_layout_stride,                  \
+                                            b_reduction_stride, accumulate)
+
+      if (rem >= 8) {
+        DISPATCH_MB_RUNTIME(8);
+        m += 8;
+      } else if (rem >= 4) {
+        DISPATCH_MB_RUNTIME(4);
+        m += 4;
+      } else if (rem >= 2) {
+        DISPATCH_MB_RUNTIME(2);
+        m += 2;
+      } else {
+        DISPATCH_MB_RUNTIME(1);
+        m += 1;
+      }
+#undef DISPATCH_MB_RUNTIME
+    }
+  }
+}
+
+#undef INIT_ACC_ROWPAIR_4
+#undef STORE_ACC_ROWPAIR_4
+#undef BFMMLA_COMPUTE_4
+
+}  // namespace
+
+// TileGemm Adapter for Attention
+
+template <int64_t BlockTokens, int64_t HeadDim>
+class TileGemmNEONBFMMLA {
+ public:
+  template <AttentionGemmPhase phase, typename kv_cache_t, int64_t head_dim_ct>
+  FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile,
+                                kv_cache_t* __restrict__ b_tile,
+                                float* __restrict__ c_tile, const int64_t lda,
+                                [[maybe_unused]] const int64_t ldb,
+                                const int64_t ldc,
+                                [[maybe_unused]] const int32_t block_size,
+                                [[maybe_unused]] const int32_t dynamic_k_size,
+                                const bool accum_c) {
+    static_assert(BlockTokens % OUTPUT_COLS_PER_BLOCK == 0);
+    // BFMMLA kernels require compile-time head_dim; keep head_dim_ct only for
+    // API parity with other tile_gemm implementations.
+    if constexpr (head_dim_ct >= 0) {
+      static_assert(head_dim_ct == HeadDim,
+                    "BFMMLA expects head_dim_ct to match HeadDim; PV passes "
+                    "-1 for API parity.");
+    }
+
+    if constexpr (phase == AttentionGemmPhase::QK) {
+      const int64_t b_reduction_stride = K_INNER_STRIDE;
+      const int64_t b_token_block_stride = (HeadDim / TILE_K) * K_INNER_STRIDE;
+
+      gemm_macro_neon_bfmmla<BlockTokens, HeadDim, AttentionGemmPhase::QK>(
+          reinterpret_cast<const c10::BFloat16*>(a_tile), b_tile, c_tile,
+          m_size, 0, lda, ldc, b_token_block_stride, b_reduction_stride,
+          accum_c);
+    } else {
+      const int64_t b_pair_stride =
+          (block_size / V_TOKENS_PER_ROW_BLOCK) * V_INNER_STRIDE;
+
+      // PV gemm with runtime K specialization
+      switch (dynamic_k_size) {
+        case 32:
+          gemm_macro_neon_bfmmla<HeadDim, 32, AttentionGemmPhase::PV>(
+              reinterpret_cast<const c10::BFloat16*>(a_tile), b_tile, c_tile,
+              m_size, 32, lda, ldc, b_pair_stride, 0, accum_c);
+          break;
+        case 128:
+          gemm_macro_neon_bfmmla<HeadDim, 128, AttentionGemmPhase::PV>(
+              reinterpret_cast<const c10::BFloat16*>(a_tile), b_tile, c_tile,
+              m_size, 128, lda, ldc, b_pair_stride, 0, accum_c);
+          break;
+        case 256:
+          gemm_macro_neon_bfmmla<HeadDim, 256, AttentionGemmPhase::PV>(
+              reinterpret_cast<const c10::BFloat16*>(a_tile), b_tile, c_tile,
+              m_size, 256, lda, ldc, b_pair_stride, 0, accum_c);
+          break;
+        default:
+          gemm_macro_neon_bfmmla<HeadDim, -1, AttentionGemmPhase::PV>(
+              reinterpret_cast<const c10::BFloat16*>(a_tile), b_tile, c_tile,
+              m_size, dynamic_k_size, lda, ldc, b_pair_stride, 0, accum_c);
+          break;
+      }
+    }
+  }
+};
+
+// Shared ASIMD BFMMLA implementation (BF16 only). The block size alignment and
+// ISA tag are template parameters so we can reuse the same kernels for
+// different NEON configurations.
+template <int64_t block_size_alignment, int64_t head_dim, ISA isa_type>
+class AttentionImplNEONBFMMLA {
+ public:
+  using query_t = c10::BFloat16;
+  using q_buffer_t = c10::BFloat16;
+  using kv_cache_t = c10::BFloat16;
+  using logits_buffer_t = float;
+  using partial_output_buffer_t = float;
+  using prob_buffer_t = c10::BFloat16;
+
+  static constexpr int64_t BlockSizeAlignment = block_size_alignment;
+  // HeadDimAlignment equals head_dim so that the PV phase processes
+  // the full head dimension in a single gemm call.
+  static constexpr int64_t HeadDimAlignment = head_dim;
+  static constexpr int64_t MaxQHeadNumPerIteration = 16;
+  static constexpr int64_t HeadDim = head_dim;
+  static constexpr ISA ISAType = isa_type;
+  static constexpr bool scale_on_logits = false;
+
+  static_assert(HeadDim % OUTPUT_COLS_PER_BLOCK == 0);
+  static_assert(BlockSizeAlignment % OUTPUT_COLS_PER_BLOCK == 0);
+  static_assert(HeadDim % TILE_K == 0, "HeadDim must be a multiple of TILE_K");
+
+ public:
+  template