#ifndef CPU_ATTN_VXE_HPP
#define CPU_ATTN_VXE_HPP

#include "cpu_attn_impl.hpp"

#include <vecintrin.h>  // s390x z/Architecture vector intrinsics (vec_xl, vec_madd, ...)

#include <cstdint>

namespace cpu_attention {
namespace {

// s390x Vector = 16 bytes (128 bits)
#define BLOCK_SIZE_ALIGNMENT 32
#define HEAD_SIZE_ALIGNMENT 32
#define MAX_Q_HEAD_NUM_PER_ITER 16

// Load 8 consecutive B elements of type kv_cache_t and widen them to two
// float vectors (4 lanes each).
template <typename kv_cache_t>
FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, __vector float& b0,
                                     __vector float& b1);

// [1] Float specialization
template <>
FORCE_INLINE void load_row8_B_as_f32(const float* p, __vector float& b0,
                                     __vector float& b1) {
  // Explicitly cast to long long for the offset and float* for the pointer.
  b0 = vec_xl((long long)0, const_cast<float*>(p));
  b1 = vec_xl((long long)0, const_cast<float*>(p + 4));
}

// [2] BFloat16 specialization (big-endian fix)
template <>
FORCE_INLINE void load_row8_B_as_f32(const c10::BFloat16* p,
                                     __vector float& b0, __vector float& b1) {
  // 1. Load 8 BF16s (16 bytes) into one vector.
  //    Explicit cast to unsigned short* so vec_xl returns vector unsigned short.
  __vector unsigned short raw = vec_xl((long long)0, (unsigned short*)p);

  // 2. Prepare a zero vector.
  __vector unsigned short zeros = vec_splat_u16(0);

  // 3. Merge high/low halves to expand BF16 -> float32.
  //    On big endian, a float is [BF16_bits | 16_zero_bits].
  b0 = (__vector float)vec_mergeh(raw, zeros);
  b1 = (__vector float)vec_mergel(raw, zeros);
}

// [3] Half specialization (scalar conversion through a temporary buffer)
template <>
FORCE_INLINE void load_row8_B_as_f32(const c10::Half* p, __vector float& b0,
                                     __vector float& b1) {
  alignas(16) float tmp[8];
  // Manual unroll / conversion
  tmp[0] = static_cast<float>(p[0]);
  tmp[1] = static_cast<float>(p[1]);
  tmp[2] = static_cast<float>(p[2]);
  tmp[3] = static_cast<float>(p[3]);
  tmp[4] = static_cast<float>(p[4]);
  tmp[5] = static_cast<float>(p[5]);
  tmp[6] = static_cast<float>(p[6]);
  tmp[7] = static_cast<float>(p[7]);
  // Explicit arguments for the intrinsic: (long long offset, float* ptr).
  b0 = vec_xl((long long)0, (float*)tmp);
  b1 = vec_xl((long long)0, (float*)(tmp + 4));
}
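// Illustrative sketch only (not part of the original kernels): a scalar
// reference for load_row8_B_as_f32, useful when cross-checking the vector
// specializations above on s390x. The helper name load_row8_B_as_f32_ref is
// ours, not part of the attention API.
template <typename kv_cache_t>
FORCE_INLINE void load_row8_B_as_f32_ref(const kv_cache_t* p, float (&out)[8]) {
  for (int i = 0; i < 8; ++i) {
    // c10::BFloat16 and c10::Half both widen to float via static_cast.
    out[i] = static_cast<float>(p[i]);
  }
}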
// Micro-kernel: C[M x 8] (+)= A[M x K] * B[K x 8], with the inner K loop
// unrolled by 4.
template <int32_t M, typename kv_cache_t>
FORCE_INLINE void gemm_micro_s390x_Mx8_Ku4(
    const float* __restrict A,       // [M x K]
    const kv_cache_t* __restrict B,  // [K x 8]
    float* __restrict C,             // [M x 8]
    int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) {
  static_assert(1 <= M && M <= 8, "M must be in [1,8]");

  // Helper macros to unroll codegen for M rows
#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7)
#define IF_M(i) if constexpr (M > (i))

  // 1. Define A row pointers
#define DECL_A(i) const float* a##i = A + (i) * lda;
  ROWS_APPLY(DECL_A)
#undef DECL_A

  // 2. Define accumulators (2 vectors cover 8 columns)
#define DECL_ACC(i) __vector float acc##i##_0, acc##i##_1;
  ROWS_APPLY(DECL_ACC)
#undef DECL_ACC

  // 3. Initialize accumulators (load C or zero)
#define INIT_ACC(i)                                                        \
  IF_M(i) {                                                                \
    if (accumulate) {                                                      \
      acc##i##_0 =                                                         \
          vec_xl((long long)0, const_cast<float*>(C + (i) * ldc + 0));     \
      acc##i##_1 =                                                         \
          vec_xl((long long)0, const_cast<float*>(C + (i) * ldc + 4));     \
    } else {                                                               \
      acc##i##_0 = vec_splats(0.0f);                                       \
      acc##i##_1 = vec_splats(0.0f);                                       \
    }                                                                      \
  }
  ROWS_APPLY(INIT_ACC)
#undef INIT_ACC

  int32_t k = 0;
  for (; k + 3 < K; k += 4) {
    // Load 4 values of A for each row m: A[k...k+3]
#define LOAD_A4(i)                                                         \
  __vector float a##i##v;                                                  \
  IF_M(i) a##i##v = vec_xl((long long)0, const_cast<float*>(a##i + k));
    ROWS_APPLY(LOAD_A4)
#undef LOAD_A4

    // Helper: FMA for a specific lane L of A
    // s390x: vec_madd(b, vec_splat(a, lane), acc)
#define FMAS_LANE(i, aiv, L)                                               \
  IF_M(i) {                                                                \
    __vector float a_broad = vec_splat(aiv, L);                            \
    acc##i##_0 = vec_madd(b0, a_broad, acc##i##_0);                        \
    acc##i##_1 = vec_madd(b1, a_broad, acc##i##_1);                        \
  }

    // Unroll K = 0..3
    {
      __vector float b0, b1;
      load_row8_B_as_f32(B + (int64_t)(k + 0) * ldb, b0, b1);
#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0)
      ROWS_APPLY(STEP_K0)
#undef STEP_K0
    }
    {
      __vector float b0, b1;
      load_row8_B_as_f32(B + (int64_t)(k + 1) * ldb, b0, b1);
#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1)
      ROWS_APPLY(STEP_K1)
#undef STEP_K1
    }
    {
      __vector float b0, b1;
      load_row8_B_as_f32(B + (int64_t)(k + 2) * ldb, b0, b1);
#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2)
      ROWS_APPLY(STEP_K2)
#undef STEP_K2
    }
    {
      __vector float b0, b1;
      load_row8_B_as_f32(B + (int64_t)(k + 3) * ldb, b0, b1);
#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3)
      ROWS_APPLY(STEP_K3)
#undef STEP_K3
    }
#undef FMAS_LANE
  }

  // K tail: handle the remaining k values one at a time.
  for (; k < K; ++k) {
    __vector float b0, b1;
    load_row8_B_as_f32(B + (int64_t)k * ldb, b0, b1);
#define TAIL_ROW(i)                                                        \
  IF_M(i) {                                                                \
    __vector float ai = vec_splats(*(a##i + k));                           \
    acc##i##_0 = vec_madd(b0, ai, acc##i##_0);                             \
    acc##i##_1 = vec_madd(b1, ai, acc##i##_1);                             \
  }
    ROWS_APPLY(TAIL_ROW)
#undef TAIL_ROW
  }

  // Store the 8 result columns of each active row.
#define STORE_ROW(i)                                                       \
  IF_M(i) {                                                                \
    vec_xst(acc##i##_0, 0, C + (i) * ldc + 0);                             \
    vec_xst(acc##i##_1, 0, C + (i) * ldc + 4);                             \
  }
  ROWS_APPLY(STORE_ROW)
#undef STORE_ROW

#undef ROWS_APPLY
#undef IF_M
}
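// Illustrative sketch only (not used by the code below): a plain scalar
// equivalent of the micro-kernel above, C[M x 8] (+)= A[M x K] * B[K x 8],
// handy as a correctness reference for the vector path. The function name
// gemm_micro_ref_Mx8 is ours.
template <int32_t M, typename kv_cache_t>
FORCE_INLINE void gemm_micro_ref_Mx8(const float* __restrict A,
                                     const kv_cache_t* __restrict B,
                                     float* __restrict C, int64_t lda,
                                     int64_t ldb, int64_t ldc, int32_t K,
                                     bool accumulate) {
  for (int32_t m = 0; m < M; ++m) {
    for (int32_t n = 0; n < 8; ++n) {
      // Start from the existing C value when accumulating, else from zero.
      float acc = accumulate ? C[m * ldc + n] : 0.0f;
      for (int32_t k = 0; k < K; ++k) {
        acc += A[m * lda + k] * static_cast<float>(B[k * ldb + n]);
      }
      C[m * ldc + n] = acc;
    }
  }
}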
// Macro-kernel: tile the [M x N] output into 8-column panels and dispatch to
// the micro-kernel with the largest row count (8/4/2/1) that still fits.
template <int32_t N, typename kv_cache_t>
FORCE_INLINE void gemm_macro_s390x_Mx8_Ku4(const float* __restrict A,
                                           const kv_cache_t* __restrict B,
                                           float* __restrict C, int32_t M,
                                           int32_t K, int64_t lda, int64_t ldb,
                                           int64_t ldc, bool accumulate) {
  static_assert(N % 8 == 0, "N must be a multiple of 8");
  for (int32_t m = 0; m < M;) {
    int32_t mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1;
    const float* Ab = A + m * lda;
    float* Cb = C + m * ldc;
    for (int32_t n = 0; n < N; n += 8) {
      const kv_cache_t* Bn = B + n;
      float* Cn = Cb + n;
      switch (mb) {
        case 8:
          gemm_micro_s390x_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
                                                  accumulate);
          break;
        case 4:
          gemm_micro_s390x_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
                                                  accumulate);
          break;
        case 2:
          gemm_micro_s390x_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
                                                  accumulate);
          break;
        default:
          gemm_micro_s390x_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
                                                  accumulate);
          break;
      }
    }
    m += mb;
  }
}

template <typename kv_cache_t, int32_t k_size>
class TileGemmS390X {
 public:
  template <AttentionGemmPhase phase>
  FORCE_INLINE static void gemm(const int32_t m_size,
                                float* __restrict__ a_tile,
                                kv_cache_t* __restrict__ b_tile,
                                float* __restrict__ c_tile, const int64_t lda,
                                const int64_t ldb, const int64_t ldc,
                                const int32_t block_size,
                                const int32_t dynamic_k_size,
                                const bool accum_c) {
    if constexpr (phase == AttentionGemmPhase::QK) {
      // QK phase: the reduction dimension is the compile-time k_size.
      gemm_macro_s390x_Mx8_Ku4(a_tile, b_tile, c_tile, m_size, k_size, lda,
                               ldb, ldc, accum_c);
    } else {
      // PV phase: the reduction dimension is the runtime dynamic_k_size.
      gemm_macro_s390x_Mx8_Ku4(a_tile, b_tile, c_tile, m_size, dynamic_k_size,
                               lda, ldb, ldc, accum_c);
    }
  }
};

}  // namespace

template <typename scalar_t, int64_t head_dim>
class AttentionImpl {
 public:
  using query_t = scalar_t;
  using q_buffer_t = float;
  using kv_cache_t = scalar_t;
  using logits_buffer_t = float;
  using partial_output_buffer_t = float;
  using prob_buffer_t = float;

  constexpr static int64_t BlockSizeAlignment = BLOCK_SIZE_ALIGNMENT;
  constexpr static int64_t HeadDimAlignment = HEAD_SIZE_ALIGNMENT;
  constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER;
  constexpr static int64_t HeadDim = head_dim;
  constexpr static ISA ISAType = ISA::VXE;
  constexpr static bool scale_on_logits = false;  // Scale is applied to Q during copy

 public:
  AttentionImpl() {}

  template