// SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright contributors to the vLLM project #ifndef CPU_ATTN_NEON_BFMMLA_HPP #define CPU_ATTN_NEON_BFMMLA_HPP #include "cpu_attn_impl.hpp" #include #include #include namespace cpu_attention { namespace { // BFMMLA tile dimensions constexpr int32_t TILE_ROWS = 2; // M dimension constexpr int32_t TILE_K = 4; // K reduction constexpr int32_t TILE_COLS = 2; // N dimension (column-pair) // Derived constants constexpr int32_t OUTPUT_COLS_PER_BLOCK = 8; // 4 column-pairs constexpr int32_t K_TOKENS_PER_GROUP = 8; // Tokens grouped in K cache constexpr int32_t V_TOKENS_PER_ROW_BLOCK = 4; // Tokens per V cache row block constexpr int32_t K_INNER_STRIDE = K_TOKENS_PER_GROUP * TILE_K; constexpr int32_t V_INNER_STRIDE = V_TOKENS_PER_ROW_BLOCK * TILE_COLS; constexpr int32_t PACK_ELEMENTS_PER_K_CHUNK = TILE_ROWS * TILE_K; // A packing // Matrix Packing and Accumulator // Reshape two rows of Q into BFMMLA-friendly interleaved // Input: row0 = [a0,a1,a2,a3], row1 = [b0,b1,b2,b3] // Output: [a0,a1,a2,a3,b0,b1,b2,b3, a4,a5,a6,a7,b4,b5,b6,b7] // For K tail (K % TILE_K != 0): pads with zeros to complete the final chunk FORCE_INLINE void reshape_Q_2xK_for_bfmmla(const c10::BFloat16* __restrict r0, const c10::BFloat16* __restrict r1, c10::BFloat16* __restrict dst, int32_t K) { const uint16_t* s0 = reinterpret_cast(r0); const uint16_t* s1 = reinterpret_cast(r1); uint16_t* d = reinterpret_cast(dst); // Process TILE_K elements at a time (PACK_ELEMENTS_PER_K_CHUNK output) int32_t k = 0; for (; k + TILE_K <= K; k += TILE_K, d += PACK_ELEMENTS_PER_K_CHUNK) { vst1q_u16(d, vcombine_u16(vld1_u16(s0 + k), vld1_u16(s1 + k))); } // Handle K tail: pack remaining elements with zero-padding const int32_t tail = K - k; if (tail > 0) { // Pack remaining tail elements: [r0[k..k+tail-1], pad, r1[k..k+tail-1], // pad] for (int32_t t = 0; t < tail; ++t) { d[t] = s0[k + t]; d[t + TILE_K] = s1[k + t]; } // Zero-pad the rest for (int32_t t = tail; t < TILE_K; ++t) { d[t] = 0; d[t + TILE_K] = 0; } } } // 2x2 accumulator load/store with compile-time row count template FORCE_INLINE float32x4_t load_acc_2x2(float* base, int64_t ldc, int col_off) { static_assert(m_rows == 1 || m_rows == 2); float32x2_t row0 = vld1_f32(base + col_off); float32x2_t row1 = (m_rows == 2) ? vld1_f32(base + ldc + col_off) : vdup_n_f32(0.f); return vcombine_f32(row0, row1); } template FORCE_INLINE void store_acc_2x2(float32x4_t acc, float* base, int64_t ldc, int col_off) { static_assert(m_rows == 1 || m_rows == 2); vst1_f32(base + col_off, vget_low_f32(acc)); if constexpr (m_rows == 2) { vst1_f32(base + ldc + col_off, vget_high_f32(acc)); } } // Initialize 4 column-pair accumulators for 2 rows (8 columns total) #define INIT_ACC_ROWPAIR_4(a0, a1, a2, a3, Crow, ldc, m_rows, accum) \ do { \ if (accum) { \ if (m_rows == 2) { \ a0 = load_acc_2x2<2>(Crow, ldc, 0); \ a1 = load_acc_2x2<2>(Crow, ldc, 2); \ a2 = load_acc_2x2<2>(Crow, ldc, 4); \ a3 = load_acc_2x2<2>(Crow, ldc, 6); \ } else { \ a0 = load_acc_2x2<1>(Crow, ldc, 0); \ a1 = load_acc_2x2<1>(Crow, ldc, 2); \ a2 = load_acc_2x2<1>(Crow, ldc, 4); \ a3 = load_acc_2x2<1>(Crow, ldc, 6); \ } \ } else { \ a0 = a1 = a2 = a3 = vdupq_n_f32(0.f); \ } \ } while (0) // Store 4 column-pair accumulators back to C matrix #define STORE_ACC_ROWPAIR_4(a0, a1, a2, a3, Crow, ldc, m_rows) \ do { \ if (m_rows == 2) { \ store_acc_2x2<2>(a0, Crow, ldc, 0); \ store_acc_2x2<2>(a1, Crow, ldc, 2); \ store_acc_2x2<2>(a2, Crow, ldc, 4); \ store_acc_2x2<2>(a3, Crow, ldc, 6); \ } else { \ store_acc_2x2<1>(a0, Crow, ldc, 0); \ store_acc_2x2<1>(a1, Crow, ldc, 2); \ store_acc_2x2<1>(a2, Crow, ldc, 4); \ store_acc_2x2<1>(a3, Crow, ldc, 6); \ } \ } while (0) // Perform 4 BFMMLA operations: acc += A @ B for 4 column-pairs #define BFMMLA_COMPUTE_4(r0, r1, r2, r3, a, b0, b1, b2, b3) \ do { \ r0 = vbfmmlaq_f32(r0, a, b0); \ r1 = vbfmmlaq_f32(r1, a, b1); \ r2 = vbfmmlaq_f32(r2, a, b2); \ r3 = vbfmmlaq_f32(r3, a, b3); \ } while (0) // Micro-kernel: updates a small fixed tile using BFMMLA. // RP = number of row-pairs (1,2,4) // Computes C[TILE_ROWS*RP, OUTPUT_COLS_PER_BLOCK] += A_packed @ B. // A_packed interleaves RP row-pairs; B layout is driven by the attention phase: // - AttentionGemmPhase::QK -> token-column layout (Q @ K^T) // - AttentionGemmPhase::PV -> token-row layout (P @ V) // K_static < 0 enables runtime K (PV only) template FORCE_INLINE void gemm_rowpairs_x8_bfmmla_neon( const bfloat16_t* const* __restrict A_packed_rp, const int32_t* __restrict m_rows_rp, const bfloat16_t* __restrict B_blk, float* __restrict C, int64_t ldc, bool accumulate, int64_t b_stride, int32_t K_runtime = 0) { static_assert(RP == 1 || RP == 2 || RP == 4, "RP must be 1,2,4"); static_assert(K_static < 0 || K_static % TILE_K == 0, "K must be divisible by TILE_K"); static_assert(K_static >= 0 || phase == AttentionGemmPhase::PV, "Runtime K only supported for PV"); constexpr bool runtime_k = (K_static < 0); const int32_t K_iters = runtime_k ? (K_runtime / TILE_K) : (K_static / TILE_K); const int32_t K_tail = runtime_k ? (K_runtime % TILE_K) : 0; if (!runtime_k) { // Help the compiler fold away unused K_runtime when K is compile-time (void)K_runtime; } auto* C_al = C; const auto* B_al = B_blk; // Setup A pointers const bfloat16_t* a_ptr[4] = { A_packed_rp[0], (RP >= 2) ? A_packed_rp[1] : nullptr, (RP >= 4) ? A_packed_rp[2] : nullptr, (RP >= 4) ? A_packed_rp[3] : nullptr, }; // Setup B pointers based on layout const bfloat16_t* b_ptr[4]; if constexpr (phase == AttentionGemmPhase::PV) { b_ptr[0] = B_blk + 0 * b_stride; b_ptr[1] = B_blk + 1 * b_stride; b_ptr[2] = B_blk + 2 * b_stride; b_ptr[3] = B_blk + 3 * b_stride; } float32x4_t acc[4][4]; // Initialize accumulators #define INIT_RP(rp) \ if constexpr (RP > rp) { \ INIT_ACC_ROWPAIR_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3], \ C_al + (rp * 2) * ldc, ldc, m_rows_rp[rp], accumulate); \ } INIT_RP(0); INIT_RP(1); INIT_RP(2); INIT_RP(3); #undef INIT_RP // Main compute loop for (int32_t ki = 0; ki < K_iters; ++ki) { bfloat16x8_t b0, b1, b2, b3; if constexpr (phase == AttentionGemmPhase::PV) { b0 = vld1q_bf16(b_ptr[0] + ki * V_INNER_STRIDE); b1 = vld1q_bf16(b_ptr[1] + ki * V_INNER_STRIDE); b2 = vld1q_bf16(b_ptr[2] + ki * V_INNER_STRIDE); b3 = vld1q_bf16(b_ptr[3] + ki * V_INNER_STRIDE); } else { const bfloat16_t* b_base = B_al + ki * b_stride; b0 = vld1q_bf16(b_base + 0 * V_INNER_STRIDE); b1 = vld1q_bf16(b_base + 1 * V_INNER_STRIDE); b2 = vld1q_bf16(b_base + 2 * V_INNER_STRIDE); b3 = vld1q_bf16(b_base + 3 * V_INNER_STRIDE); } #define COMPUTE_RP(rp) \ if constexpr (RP > rp) { \ bfloat16x8_t a = vld1q_bf16(a_ptr[rp] + ki * PACK_ELEMENTS_PER_K_CHUNK); \ BFMMLA_COMPUTE_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3], a, b0, \ b1, b2, b3); \ } COMPUTE_RP(0); COMPUTE_RP(1); COMPUTE_RP(2); COMPUTE_RP(3); #undef COMPUTE_RP } // K tail for runtime PV: fallback path if constexpr (runtime_k) { if (K_tail > 0) { const int32_t tail_offset = K_iters * V_INNER_STRIDE; const int32_t a_tail_offset = K_iters * PACK_ELEMENTS_PER_K_CHUNK; for (int32_t kt = 0; kt < K_tail; ++kt) { float32x4_t b_vecs[4]; for (int32_t p = 0; p < 4; ++p) { const bfloat16_t* bp = b_ptr[p] + tail_offset + kt * TILE_COLS; const float b0 = vcvtah_f32_bf16(bp[0]); const float b1 = vcvtah_f32_bf16(bp[1]); const float32x2_t b_pair = vset_lane_f32(b1, vdup_n_f32(b0), 1); b_vecs[p] = vcombine_f32(b_pair, b_pair); } #define TAIL_RP(rp) \ if constexpr (RP > rp) { \ const bfloat16_t* ap = A_packed_rp[rp] + a_tail_offset; \ float a_row0 = vcvtah_f32_bf16(ap[kt]); \ float a_row1 = \ (m_rows_rp[rp] == 2) ? vcvtah_f32_bf16(ap[kt + TILE_K]) : 0.0f; \ const float32x4_t a_vec = \ vcombine_f32(vdup_n_f32(a_row0), vdup_n_f32(a_row1)); \ for (int32_t p = 0; p < 4; ++p) { \ acc[rp][p] = vmlaq_f32(acc[rp][p], a_vec, b_vecs[p]); \ } \ } TAIL_RP(0); TAIL_RP(1); TAIL_RP(2); TAIL_RP(3); #undef TAIL_RP } } } // Store results #define STORE_RP(rp) \ if constexpr (RP > rp) { \ STORE_ACC_ROWPAIR_4(acc[rp][0], acc[rp][1], acc[rp][2], acc[rp][3], \ C_al + (rp * 2) * ldc, ldc, m_rows_rp[rp]); \ } STORE_RP(0); STORE_RP(1); STORE_RP(2); STORE_RP(3); #undef STORE_RP } // Meso-kernel: packs a small MBxK slice of A, then tiles over N and calls the // micro-kernel for each OUTPUT_COLS_PER_BLOCK chunk. K_static < 0 enables // runtime K (PV only). template FORCE_INLINE void gemm_packA_compute_MB_xN( const c10::BFloat16* __restrict A, const c10::BFloat16* __restrict B, float* __restrict C, int32_t K_runtime, int64_t lda, int64_t ldc, int64_t b_layout_stride, int64_t b_reduction_stride, bool accumulate) { static_assert(MB >= 1 && MB <= 8, "MB must be in [1,8]"); static_assert(N % OUTPUT_COLS_PER_BLOCK == 0, "N must be a multiple of OUTPUT_COLS_PER_BLOCK"); static_assert(K_static < 0 || K_static % TILE_K == 0, "K must be divisible by TILE_K"); static_assert(K_static >= 0 || phase == AttentionGemmPhase::PV, "Runtime K only supported for PV"); constexpr bool runtime_k = (K_static < 0); const int32_t K_val = runtime_k ? K_runtime : K_static; // Keep small packs on-stack to avoid heap churn constexpr int32_t STACK_PACK_STRIDE = (1024 / TILE_K) * PACK_ELEMENTS_PER_K_CHUNK; constexpr int32_t ROW_PAIRS = (MB + 1) / TILE_ROWS; const int32_t pack_stride = runtime_k ? ((K_val + TILE_K - 1) / TILE_K) * PACK_ELEMENTS_PER_K_CHUNK : (K_static / TILE_K) * PACK_ELEMENTS_PER_K_CHUNK; alignas(64) c10::BFloat16 A_packed_stack[ROW_PAIRS * STACK_PACK_STRIDE]; std::vector A_packed_heap; c10::BFloat16* A_packed = (pack_stride <= STACK_PACK_STRIDE) ? A_packed_stack : (A_packed_heap.resize(ROW_PAIRS * pack_stride), A_packed_heap.data()); for (int32_t rp = 0; rp < ROW_PAIRS; ++rp) { const int32_t m = rp * TILE_ROWS; const int32_t m_rows = (m + 1 < MB) ? TILE_ROWS : 1; const c10::BFloat16* A0 = A + m * lda; const c10::BFloat16* A1 = (m_rows == TILE_ROWS) ? (A + (m + 1) * lda) : A0; reshape_Q_2xK_for_bfmmla(A0, A1, A_packed + rp * pack_stride, K_val); } for (int32_t n = 0; n < N; n += OUTPUT_COLS_PER_BLOCK) { const c10::BFloat16* B_blk_c10 = (phase == AttentionGemmPhase::PV) ? (B + (n / TILE_COLS) * b_layout_stride) : (B + (n / OUTPUT_COLS_PER_BLOCK) * b_layout_stride); const bfloat16_t* B_blk = reinterpret_cast(B_blk_c10); // Process row-pairs in groups of 4, 2, then 1 int32_t row_pair_idx = 0; #define PROCESS_RP_GROUP(group_size) \ for (; row_pair_idx + (group_size - 1) < ROW_PAIRS; \ row_pair_idx += group_size) { \ const bfloat16_t* Ap[group_size]; \ int32_t mr[group_size]; \ for (int32_t i = 0; i < group_size; ++i) { \ Ap[i] = reinterpret_cast( \ A_packed + (row_pair_idx + i) * pack_stride); \ mr[i] = (((row_pair_idx + i) * TILE_ROWS + 1) < MB) ? TILE_ROWS : 1; \ } \ float* C_blk = C + (row_pair_idx * TILE_ROWS) * ldc + n; \ if constexpr (runtime_k) { \ gemm_rowpairs_x8_bfmmla_neon( \ Ap, mr, B_blk, C_blk, ldc, accumulate, b_layout_stride, K_val); \ } else { \ gemm_rowpairs_x8_bfmmla_neon( \ Ap, mr, B_blk, C_blk, ldc, accumulate, \ (phase == AttentionGemmPhase::PV) ? b_layout_stride \ : b_reduction_stride); \ } \ } PROCESS_RP_GROUP(4); PROCESS_RP_GROUP(2); PROCESS_RP_GROUP(1); #undef PROCESS_RP_GROUP } } // Macro-kernel: iterates over M in MB={8,4,2,1} chunks. // Supports compile-time K specialization when K >= 0; otherwise uses runtime K // (runtime K path is only supported for PV). template FORCE_INLINE void gemm_macro_neon_bfmmla( const c10::BFloat16* __restrict A, const c10::BFloat16* __restrict B, float* __restrict C, int32_t M, int32_t K_runtime, int64_t lda, int64_t ldc, int64_t b_layout_stride, int64_t b_reduction_stride, bool accumulate) { static_assert(N % OUTPUT_COLS_PER_BLOCK == 0, "N must be a multiple of OUTPUT_COLS_PER_BLOCK"); if constexpr (K >= 0) { static_assert(K % TILE_K == 0, "K must be divisible by TILE_K"); for (int32_t m = 0; m < M;) { const int32_t rem = M - m; const c10::BFloat16* A_blk = A + m * lda; float* C_blk = C + m * ldc; #define DISPATCH_MB(mb) \ gemm_packA_compute_MB_xN(A_blk, B, C_blk, 0, lda, ldc, \ b_layout_stride, \ b_reduction_stride, accumulate) if (rem >= 8) { DISPATCH_MB(8); m += 8; } else if (rem >= 4) { DISPATCH_MB(4); m += 4; } else if (rem >= 2) { DISPATCH_MB(2); m += 2; } else { DISPATCH_MB(1); m += 1; } #undef DISPATCH_MB } } else { static_assert(phase == AttentionGemmPhase::PV, "Runtime K specialization only supported for PV."); const int32_t K_val = K_runtime; for (int32_t m = 0; m < M;) { const int32_t rem = M - m; const c10::BFloat16* A_blk = A + m * lda; float* C_blk = C + m * ldc; #define DISPATCH_MB_RUNTIME(mb) \ gemm_packA_compute_MB_xN(A_blk, B, C_blk, K_val, lda, ldc, \ b_layout_stride, \ b_reduction_stride, accumulate) if (rem >= 8) { DISPATCH_MB_RUNTIME(8); m += 8; } else if (rem >= 4) { DISPATCH_MB_RUNTIME(4); m += 4; } else if (rem >= 2) { DISPATCH_MB_RUNTIME(2); m += 2; } else { DISPATCH_MB_RUNTIME(1); m += 1; } #undef DISPATCH_MB_RUNTIME } } } #undef INIT_ACC_ROWPAIR_4 #undef STORE_ACC_ROWPAIR_4 #undef BFMMLA_COMPUTE_4 } // namespace // TileGemm Adapter for Attention template class TileGemmNEONBFMMLA { public: template FORCE_INLINE static void gemm(const int32_t m_size, void* __restrict__ a_tile, kv_cache_t* __restrict__ b_tile, float* __restrict__ c_tile, const int64_t lda, [[maybe_unused]] const int64_t ldb, const int64_t ldc, [[maybe_unused]] const int32_t block_size, [[maybe_unused]] const int32_t dynamic_k_size, const bool accum_c) { static_assert(BlockTokens % OUTPUT_COLS_PER_BLOCK == 0); // BFMMLA kernels require compile-time head_dim; keep head_dim_ct only for // API parity with other tile_gemm implementations. if constexpr (head_dim_ct >= 0) { static_assert(head_dim_ct == HeadDim, "BFMMLA expects head_dim_ct to match HeadDim; PV passes " "-1 for API parity."); } if constexpr (phase == AttentionGemmPhase::QK) { const int64_t b_reduction_stride = K_INNER_STRIDE; const int64_t b_token_block_stride = (HeadDim / TILE_K) * K_INNER_STRIDE; gemm_macro_neon_bfmmla( reinterpret_cast(a_tile), b_tile, c_tile, m_size, 0, lda, ldc, b_token_block_stride, b_reduction_stride, accum_c); } else { const int64_t b_pair_stride = (block_size / V_TOKENS_PER_ROW_BLOCK) * V_INNER_STRIDE; // PV gemm with runtime K specialization switch (dynamic_k_size) { case 32: gemm_macro_neon_bfmmla( reinterpret_cast(a_tile), b_tile, c_tile, m_size, 32, lda, ldc, b_pair_stride, 0, accum_c); break; case 128: gemm_macro_neon_bfmmla( reinterpret_cast(a_tile), b_tile, c_tile, m_size, 128, lda, ldc, b_pair_stride, 0, accum_c); break; case 256: gemm_macro_neon_bfmmla( reinterpret_cast(a_tile), b_tile, c_tile, m_size, 256, lda, ldc, b_pair_stride, 0, accum_c); break; default: gemm_macro_neon_bfmmla( reinterpret_cast(a_tile), b_tile, c_tile, m_size, dynamic_k_size, lda, ldc, b_pair_stride, 0, accum_c); break; } } } }; // Shared ASIMD BFMMLA implementation (BF16 only). The block size alignment and // ISA tag are template parameters so we can reuse the same kernels for // different NEON configurations. template class AttentionImplNEONBFMMLA { public: using query_t = c10::BFloat16; using q_buffer_t = c10::BFloat16; using kv_cache_t = c10::BFloat16; using logits_buffer_t = float; using partial_output_buffer_t = float; using prob_buffer_t = c10::BFloat16; static constexpr int64_t BlockSizeAlignment = block_size_alignment; // HeadDimAlignment equals head_dim so that the PV phase processes // the full head dimension in a single gemm call. static constexpr int64_t HeadDimAlignment = head_dim; static constexpr int64_t MaxQHeadNumPerIteration = 16; static constexpr int64_t HeadDim = head_dim; static constexpr ISA ISAType = isa_type; static constexpr bool scale_on_logits = false; static_assert(HeadDim % OUTPUT_COLS_PER_BLOCK == 0); static_assert(BlockSizeAlignment % OUTPUT_COLS_PER_BLOCK == 0); static_assert(HeadDim % TILE_K == 0, "HeadDim must be a multiple of TILE_K"); public: template